# Allen Institute Software License - This software license is the 2-clause BSD
# license plus a third clause that prohibits redistribution for commercial
# purposes without further permission.
#
# Copyright 2017. Allen Institute. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Redistributions for commercial purposes are not permitted without the
# Allen Institute's written permission.
# For purposes of this license, commercial purposes is the incorporation of the
# Allen Institute's software into anything for which you will charge fees or
# other compensation. Contact terms@alleninstitute.org for commercial licensing
# opportunities.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
import warnings
import logging
import numpy as np
import scipy.signal as signal
from scipy.optimize import curve_fit
from functools import partial


def detect_putative_spikes(v, t, start=None, end=None, filter=10.0, dv_cutoff=20.0):
    """Locate candidate spikes from upward crossings of a dV/dt cutoff.

    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    start : start of time window for spike detection (optional, defaults to t[0])
    end : end of time window for spike detection (optional, defaults to t[-1])
    filter : cutoff frequency for 4-pole low-pass Bessel filter in kHz (optional, default 10)
    dv_cutoff : minimum dV/dt to qualify as a spike in V/s (optional, default 20)

    Returns
    -------
    putative_spikes : numpy array of preliminary spike indexes, expressed as
                      positions in the full `v`/`t` arrays (not the window)

    Raises
    ------
    TypeError : if `v` or `t` is not a numpy array
    FeatureError : if `v` and `t` do not have the same shape
    """
    if not isinstance(v, np.ndarray):
        raise TypeError("v is not an np.ndarray")
    if not isinstance(t, np.ndarray):
        raise TypeError("t is not an np.ndarray")
    if v.shape != t.shape:
        raise FeatureError("Voltage and time series do not have the same dimensions")

    window_start = find_time_index(t, t[0] if start is None else start)
    window_stop = find_time_index(t, t[-1] if end is None else end)
    window = slice(window_start, window_stop + 1)

    dvdt = calculate_dvdt(v[window], t[window], filter)

    # A difference of +1 marks a sample below the cutoff that is immediately
    # followed by a sample at or above it (a positive-going crossing)
    at_or_above = np.greater_equal(dvdt, dv_cutoff).astype(int)
    crossings = np.flatnonzero(np.diff(at_or_above) == 1)

    # A window that already starts above the cutoff counts as a crossing at 0
    if dvdt[0] > dv_cutoff:
        crossings = np.insert(crossings, 0, 0)

    if len(crossings) > 1:
        # Keep a later crossing only if dV/dt went negative somewhere since
        # the crossing immediately before it
        retained = [crossings[0]]
        for previous, current in zip(crossings[:-1], crossings[1:]):
            if np.any(dvdt[previous:current] < 0):
                retained.append(current)
        crossings = retained

    # Shift from window-relative positions back to the original index space
    return np.array(crossings) + window_start


def find_peak_indexes(v, t, spike_indexes, end=None):
    """Find indexes of spike peaks.

    The peak of each spike is the position of the maximum voltage between that
    spike's index and the next spike's index; for the last spike the search
    runs up to the index of `end`.

    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    spike_indexes : numpy array of preliminary spike indexes
    end : end of time window for spike detection (optional, defaults to t[-1])

    Returns
    -------
    peak_indexes : numpy array of spike peak indexes
    """
    # Compare against None explicitly so that an end time of 0 is honoured
    # rather than being silently replaced by the last time point
    if end is None:
        end = t[-1]
    end_index = find_time_index(t, end)

    # Search interval boundaries: every spike index followed by the window end
    bounds = np.append(spike_indexes, end_index)
    peak_indexes = [
        np.argmax(v[lower:upper]) + lower
        for lower, upper in zip(bounds[:-1], bounds[1:])
    ]
    return np.array(peak_indexes)


def filter_putative_spikes(
    v,
    t,
    spike_indexes,
    peak_indexes,
    min_height=2.0,
    min_peak=-30.0,
    filter=10.0,
    dvdt=None,
):
    """Discard putative events that do not look like spikes.

    An event is removed when:
        * dV/dt never goes negative between a peak and the next event's threshold
          (the earlier peak and the later threshold are dropped together)
        * the absolute peak voltage is below `min_peak`
        * the threshold-to-peak height is below `min_height`

    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    spike_indexes : numpy array of preliminary spike indexes
    peak_indexes : numpy array of indexes of spike peaks
    min_height : minimum acceptable height from threshold to peak in mV (optional, default 2)
    min_peak : minimum acceptable absolute peak level in mV (optional, default -30)
    filter : cutoff frequency for 4-pole low-pass Bessel filter in kHz (optional, default 10)
    dvdt : pre-calculated time-derivative of voltage (optional)

    Returns
    -------
    spike_indexes : numpy array of threshold indexes
    peak_indexes : numpy array of peak indexes
    """
    if spike_indexes.size == 0 or peak_indexes.size == 0:
        return np.array([]), np.array([])

    if dvdt is None:
        dvdt = calculate_dvdt(v, t, filter)

    # One flag per consecutive (peak, next threshold) pair
    falls_between = []
    for peak, next_threshold in zip(peak_indexes[:-1], spike_indexes[1:]):
        falls_between.append(np.any(dvdt[peak:next_threshold] < 0))

    # The last peak and the first threshold are always retained at this stage
    peak_indexes = peak_indexes[np.array(falls_between + [True])]
    spike_indexes = spike_indexes[np.array([True] + falls_between)]

    peak_voltage = v[peak_indexes]
    tall_enough = (peak_voltage - v[spike_indexes]) >= min_height
    keep = (peak_voltage >= min_peak) & tall_enough

    return spike_indexes[keep], peak_indexes[keep]


def find_upstroke_indexes(v, t, spike_indexes, peak_indexes, filter=10.0, dvdt=None):
    """Locate the point of maximum dV/dt between each threshold and its peak.

    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    spike_indexes : numpy array of preliminary spike indexes
    peak_indexes : numpy array of indexes of spike peaks
    filter : cutoff frequency for 4-pole low-pass Bessel filter in kHz (optional, default 10)
    dvdt : pre-calculated time-derivative of voltage (optional)

    Returns
    -------
    upstroke_indexes : numpy array of upstroke indexes
    """
    if dvdt is None:
        dvdt = calculate_dvdt(v, t, filter)

    upstrokes = []
    for threshold, peak in zip(spike_indexes, peak_indexes):
        # argmax is relative to the slice, so offset by the threshold index
        steepest = np.argmax(dvdt[threshold:peak])
        upstrokes.append(steepest + threshold)

    return np.array(upstrokes)


def refine_threshold_indexes_based_on_third_derivative(
    v, t, peak_indexes, upstroke_indexes, filter=10.0, dvdt=None
):
    """Refine spike thresholds using the third time-derivative of the voltage.

    For each spike, the search runs backwards from the upstroke down to the
    voltage minimum found between the preceding peak (index 0 for the first
    spike) and the upstroke. The threshold is placed where the third
    derivative is largest within that stretch.

    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    peak_indexes : numpy array of indexes of action potential peaks
    upstroke_indexes : numpy array of indexes of spike upstrokes
    filter : cutoff frequency for 4-pole low-pass Bessel filter in kHz (optional, default 10)
    dvdt : pre-calculated time-derivative of voltage (optional)

    Returns
    -------
    threshold_indexes : numpy array of threshold indexes
    """
    if upstroke_indexes.size == 0:
        return np.array([])

    if dvdt is None:
        dvdt = calculate_dvdt(v, t, filter)
    # Each differentiation is paired with a time vector one sample shorter
    second_derivative = calculate_dvdt(dvdt, t[:-1], filter)
    third_derivative = calculate_dvdt(second_derivative, t[:-2], filter)

    # Lower search bound per spike: start of trace, then each earlier peak
    search_origins = np.append(np.array([0]), peak_indexes)[:-1]

    thresholds = []
    for origin, upstroke in zip(search_origins, upstroke_indexes):
        lowest_voltage = origin + np.argmin(v[origin:upstroke])
        # Reversed slice: position 0 is the upstroke, increasing offsets go back in time
        steps_back = np.argmax(third_derivative[upstroke:lowest_voltage:-1])
        thresholds.append(upstroke - steps_back)

    return np.array(thresholds)


def refine_threshold_indexes_updated(
    v, t, upstroke_indexes, start=None, thresh_frac=0.05, filter=10.0, dvdt=None
):
    """Refine threshold detection of previously-found spikes, with special
    handling of the search interval that begins at the start of the window.

    The threshold of a spike is normally the first point, going backwards from
    its upstroke towards the previous upstroke, where dV/dt is at or below
    `thresh_frac` times the average upstroke dV/dt. When the search interval
    begins at the index of `start` (and that index is not the index of
    t = 0.1 s), the threshold is instead placed at the minimum dV/dt in the
    interval, to cope with depolarisations that are too steep.

    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    upstroke_indexes : numpy array of indexes of spike upstrokes (for threshold target calculation)
    start : start of time window (optional, defaults to t[0])
    thresh_frac : fraction of average upstroke for threshold calculation (optional, default 0.05)
    filter : cutoff frequency for 4-pole low-pass Bessel filter in kHz (optional, default 10)
    dvdt : pre-calculated time-derivative of voltage (optional)

    Returns
    -------
    threshold_indexes : numpy array of threshold indexes
    """
    if start is None:
        start = t[0]
    start_index = find_time_index(t, start)

    if upstroke_indexes.size == 0:
        return np.array([])

    if dvdt is None:
        dvdt = calculate_dvdt(v, t, filter)

    target = dvdt[upstroke_indexes].mean() * thresh_frac

    # Search boundaries: the window start followed by every upstroke
    bounds = np.append(np.array([start_index]), upstroke_indexes)

    thresholds = []
    for lower, upstroke in zip(bounds[:-1], bounds[1:]):
        # dV/dt walking backwards from the upstroke (lower bound excluded)
        backwards = dvdt[upstroke:lower:-1]

        if lower == start_index and lower != find_time_index(t, 0.1):
            # Too steep depolarisations
            thresholds.append(upstroke - np.argmin(backwards))
            continue

        hits = np.flatnonzero(backwards <= target)
        # With no value at or below the target, fall back to the start of
        # the search interval
        thresholds.append(upstroke - hits[0] if hits.size else lower)

    return np.array(thresholds)


def refine_threshold_indexes(
    v, t, upstroke_indexes, thresh_frac=0.05, filter=10.0, dvdt=None
):
    """Refine threshold detection of previously-found spikes.

    The threshold of a spike is the first point, going backwards from its
    upstroke towards the previous upstroke (index 0 for the first spike),
    where dV/dt is at or below `thresh_frac` times the average upstroke dV/dt.

    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    upstroke_indexes : numpy array of indexes of spike upstrokes (for threshold target calculation)
    thresh_frac : fraction of average upstroke for threshold calculation (optional, default 0.05)
    filter : cutoff frequency for 4-pole low-pass Bessel filter in kHz (optional, default 10)
    dvdt : pre-calculated time-derivative of voltage (optional)

    Returns
    -------
    threshold_indexes : numpy array of threshold indexes
    """
    if upstroke_indexes.size == 0:
        return np.array([])

    if dvdt is None:
        dvdt = calculate_dvdt(v, t, filter)

    target = dvdt[upstroke_indexes].mean() * thresh_frac

    # Search boundaries: the start of the trace followed by every upstroke
    bounds = np.append(np.array([0]), upstroke_indexes)

    thresholds = []
    for lower, upstroke in zip(bounds[:-1], bounds[1:]):
        hits = np.flatnonzero(dvdt[upstroke:lower:-1] <= target)
        if hits.size:
            thresholds.append(upstroke - hits[0])
        else:
            # No value at or below the target: use the start of the search interval
            thresholds.append(lower)

    return np.array(thresholds)


def check_thresholds_and_peaks(
    v,
    t,
    spike_indexes,
    peak_indexes,
    upstroke_indexes,
    end=None,
    max_interval=0.005,
    thresh_frac=0.05,
    filter=10.0,
    dvdt=None,
    tol=1.0,
):
    """Validate thresholds and peaks for set of spikes

    Check that peaks and thresholds for consecutive spikes do not overlap.
    Spikes with overlapping thresholds and peaks will be merged: the later
    threshold and the earlier peak and upstroke of the pair are discarded.

    Check that peaks and thresholds for a given spike are not too far apart.
    If they are, first try to re-find the threshold, then the peak; spikes for
    which neither works are dropped.

    Finally, flag the last spike as clipped if the voltage never returns to
    its threshold voltage plus `tol` before `end`.

    Note that when no overlapping spikes are removed, corrected thresholds and
    peaks are written into the `spike_indexes` and `peak_indexes` arrays that
    were passed in.

    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    spike_indexes : numpy array of spike indexes
    peak_indexes : numpy array of indexes of spike peaks
    upstroke_indexes : numpy array of indexes of spike upstrokes
    end : end of time window (optional, defaults to t[-1])
    max_interval : maximum allowed time between start of spike and time of peak in sec (default 0.005)
    thresh_frac : fraction of average upstroke for threshold calculation (optional, default 0.05)
    filter : cutoff frequency for 4-pole low-pass Bessel filter in kHz (optional, default 10)
    dvdt : pre-calculated time-derivative of voltage (optional)
    tol : tolerance for returning to threshold in mV (optional, default 1)

    Returns
    -------
    spike_indexes : numpy array of modified spike indexes
    peak_indexes : numpy array of modified spike peak indexes
    upstroke_indexes : numpy array of modified spike upstroke indexes
    clipped : numpy array of clipped status of spikes
    """

    # Compare against None explicitly so that an end time of 0 is honoured
    # rather than being silently replaced by the last time point
    if end is None:
        end = t[-1]

    # A spike overlaps its predecessor when its threshold is at or before the
    # sample just after the predecessor's peak
    overlaps = np.flatnonzero(spike_indexes[1:] <= peak_indexes[:-1] + 1)
    if overlaps.size:
        spike_mask = np.ones_like(spike_indexes, dtype=bool)
        spike_mask[overlaps + 1] = False
        spike_indexes = spike_indexes[spike_mask]

        peak_mask = np.ones_like(peak_indexes, dtype=bool)
        peak_mask[overlaps] = False
        peak_indexes = peak_indexes[peak_mask]

        upstroke_mask = np.ones_like(upstroke_indexes, dtype=bool)
        upstroke_mask[overlaps] = False
        upstroke_indexes = upstroke_indexes[upstroke_mask]

    # Validate that peaks don't occur too long after the threshold
    # If they do, try to re-find threshold from the peak
    too_long_spikes = []
    for i, (spk, peak) in enumerate(zip(spike_indexes, peak_indexes)):
        if t[peak] - t[spk] >= max_interval:
            logging.info(
                "Need to recalculate threshold-peak pair that exceeds maximum allowed interval ({:f} s)".format(
                    max_interval
                )
            )
            too_long_spikes.append(i)

    if too_long_spikes:
        if dvdt is None:
            dvdt = calculate_dvdt(v, t, filter)
        avg_upstroke = dvdt[upstroke_indexes].mean()
        target = avg_upstroke * thresh_frac
        drop_spikes = []
        for i in too_long_spikes:
            # First guessing that threshold is wrong and peak is right
            peak = peak_indexes[i]
            t_0 = find_time_index(t, t[peak] - max_interval)
            # Walk backwards from the upstroke, no further than max_interval
            # before the peak, looking for dV/dt at or below the target
            below_target = np.flatnonzero(
                dvdt[upstroke_indexes[i] : t_0 : -1] <= target
            )
            if not below_target.size:
                # Now try to see if threshold was right but peak was wrong

                # Find the peak in a window twice the size of our allowed window
                spike = spike_indexes[i]
                t_0 = find_time_index(t, t[spike] + 2 * max_interval)
                new_peak = np.argmax(v[spike:t_0]) + spike

                # If that peak is okay (not outside the allowed window, not past the next spike)
                # then keep it
                if t[new_peak] - t[spike] < max_interval and (
                    i == len(spike_indexes) - 1 or t[new_peak] < t[spike_indexes[i + 1]]
                ):
                    peak_indexes[i] = new_peak
                else:
                    # Otherwise, log and get rid of the spike
                    logging.info(
                        "Could not redetermine threshold-peak pair - dropping that pair"
                    )
                    drop_spikes.append(i)
            else:
                spike_indexes[i] = upstroke_indexes[i] - below_target[0]

        if drop_spikes:
            spike_indexes = np.delete(spike_indexes, drop_spikes)
            peak_indexes = np.delete(peak_indexes, drop_spikes)
            upstroke_indexes = np.delete(upstroke_indexes, drop_spikes)

    # Check that last spike was not cut off too early by end of stimulus
    # by checking that the membrane potential returned to at least the threshold
    # voltage (plus tolerance) - otherwise, mark it as clipped
    clipped = np.zeros_like(spike_indexes, dtype=bool)
    end_index = find_time_index(t, end)
    if len(spike_indexes) > 0 and not np.any(
        v[peak_indexes[-1] : end_index + 1] <= v[spike_indexes[-1]] + tol
    ):
        logging.debug(
            "Failed to return to threshold voltage + tolerance (%.2f) after last spike (min %.2f) - marking last spike as clipped",
            v[spike_indexes[-1]] + tol,
            v[peak_indexes[-1] : end_index + 1].min(),
        )
        clipped[-1] = True
    return spike_indexes, peak_indexes, upstroke_indexes, clipped


def check_threshold_w_peak(v, t, spike_indexes, peak_indexes, clipped):
    """Remove events whose peak follows the threshold too closely.

    Noisy 'spikes' with a threshold way too close to the peak should be
    removed. Lowering the filter cutoff sometimes helps, but risks missing
    real spikes; this check is an alternative approach.

    Parameters
    ----------
    v : numpy array of voltage time series in mV (not used by this check)
    t : numpy array of times in seconds
    spike_indexes : numpy array of spike (threshold) indexes
    peak_indexes : numpy array of spike peak indexes
    clipped : boolean numpy array of clipped status of spikes

    Returns
    -------
    spike_indexes : numpy array of retained spike indexes
    peak_indexes : numpy array of retained peak indexes
    clipped : boolean numpy array of clipped status of retained spikes
    """
    if spike_indexes.size == 0 or peak_indexes.size == 0:
        return np.array([]), np.array([]), np.zeros_like(np.array([]), dtype=bool)

    # A peak arriving no more than about a fifth of a millisecond after the
    # putative threshold is probably noise
    keep = []
    for peak, spike in zip(peak_indexes, spike_indexes):
        keep.append((t[peak] - t[spike]) > 0.0002)

    return spike_indexes[keep], peak_indexes[keep], clipped[keep]


def find_trough_indexes(v, t, spike_indexes, peak_indexes, clipped=None, end=None):
    """
    Find indexes of minimum voltage (trough) between spikes.

    For every spike but the last, the trough is the voltage minimum between
    its peak and the next spike's index. For the last spike it is the minimum
    between its peak and the index of `end`, or NaN if that spike is clipped.

    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    spike_indexes : numpy array of spike indexes
    peak_indexes : numpy array of spike peak indexes
    clipped : boolean numpy array of clipped status of spikes (optional, defaults to all False)
    end : end of time window (optional, defaults to t[-1])

    Returns
    -------
    trough_indexes : numpy array (float) of trough indexes
    """
    if spike_indexes.size == 0 or peak_indexes.size == 0:
        return np.array([])

    if clipped is None:
        clipped = np.zeros_like(spike_indexes, dtype=bool)

    if end is None:
        end = t[-1]
    end_index = find_time_index(t, end)

    # Float storage so that an undefined trough can be recorded as NaN
    troughs = np.zeros_like(spike_indexes, dtype=float)

    between_spikes = []
    for peak, next_spike in zip(peak_indexes[:-1], spike_indexes[1:]):
        between_spikes.append(peak + v[peak:next_spike].argmin())
    troughs[:-1] = between_spikes

    last_peak = peak_indexes[-1]
    if clipped[-1]:
        # If last spike is cut off by the end of the window, trough is undefined
        troughs[-1] = np.nan
    else:
        troughs[-1] = last_peak + v[last_peak:end_index].argmin()

    return troughs


def check_trough_w_peak(
    spike_indexes,
    upstroke_indexes,
    peak_indexes,
    trough_indexes,
    clipped,
    filter=10.0,
    dvdt=None,
):
    """Remove events whose peak and trough fall on the same index.

    A bump (e.g. caused by the stimulus at offset) can be classified as a
    spike. Lowering the filter cutoff can help, but may need to go so low that
    real spikes are missed. For such bumps the detected trough and peak
    coincide; since a real peak and trough cannot occur at the same time,
    those events are removed.

    Parameters
    ----------
    spike_indexes : numpy array of spike indexes
    upstroke_indexes : numpy array of spike upstroke indexes
    peak_indexes : numpy array of spike peak indexes
    trough_indexes : numpy array of trough indexes
    clipped : boolean numpy array of clipped status of spikes
    filter : not used by this check
    dvdt : not used by this check

    Returns
    -------
    spike_indexes, upstroke_indexes, peak_indexes, trough_indexes, clipped :
        the input arrays restricted to the retained events
    """
    if spike_indexes.size == 0 or peak_indexes.size == 0:
        no_indexes = [np.array([]) for _ in range(4)]
        no_clipped = np.zeros_like(np.array([]), dtype=bool)
        return no_indexes[0], no_indexes[1], no_indexes[2], no_indexes[3], no_clipped

    keep = []
    for peak, trough in zip(peak_indexes, trough_indexes):
        keep.append(peak != trough)

    return (
        spike_indexes[keep],
        upstroke_indexes[keep],
        peak_indexes[keep],
        trough_indexes[keep],
        clipped[keep],
    )


def find_downstroke_indexes(
    v, t, peak_indexes, trough_indexes, clipped=None, filter=10.0, dvdt=None
):
    """Find indexes of the minimum dV/dt (downstroke) between each peak and its trough.

    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    peak_indexes : numpy array of spike peak indexes
    trough_indexes : numpy array of trough indexes
    clipped: boolean array - False if spike not clipped by edge of window
    filter : cutoff frequency for 4-pole low-pass Bessel filter in kHz (optional, default 10)
    dvdt : pre-calculated time-derivative of voltage (optional)

    Returns
    -------
    downstroke_indexes : numpy array of downstroke indexes (NaN for clipped spikes)

    Raises
    ------
    FeatureError : if there are more troughs than peaks
    """
    if trough_indexes.size == 0:
        return np.array([])

    if dvdt is None:
        dvdt = calculate_dvdt(v, t, filter)

    if clipped is None:
        clipped = np.zeros_like(peak_indexes, dtype=bool)

    if len(trough_indexes) > len(peak_indexes):
        raise FeatureError("Cannot have more troughs than peaks")

    unclipped = ~clipped
    peaks = peak_indexes[unclipped].astype(int)
    troughs = trough_indexes[unclipped].astype(int)

    # Start from all-NaN; only entries for unclipped spikes are filled in
    downstroke_indexes = np.zeros_like(peak_indexes) * np.nan

    steepest_falls = []
    for peak, trough in zip(peaks, troughs):
        steepest_falls.append(peak + np.argmin(dvdt[peak:trough]))
    downstroke_indexes[unclipped] = steepest_falls

    return downstroke_indexes


def find_widths(v, t, spike_indexes, peak_indexes, trough_indexes, clipped=None):
    """Find widths at half-height for spikes.

    Height is measured from trough to peak. If the resulting half-height level
    lies below the spike's threshold voltage, the level halfway between
    threshold and peak is used instead. Widths are only returned when heights
    are defined; clipped spikes and spikes with a NaN trough get a NaN width,
    as do spikes where the level is never reached on either side of the peak.

    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    spike_indexes : numpy array of spike indexes
    peak_indexes : numpy array of spike peak indexes
    trough_indexes : numpy array of trough indexes
    clipped : boolean numpy array of clipped status of spikes (optional, defaults to all False)

    Returns
    -------
    widths : numpy array of spike widths in sec

    Raises
    ------
    FeatureError : if there are more troughs than spikes
    """
    if spike_indexes.size == 0 or peak_indexes.size == 0:
        return np.array([])

    if len(trough_indexes) > len(spike_indexes):
        raise FeatureError("Cannot have more troughs than spikes")

    if clipped is None:
        clipped = np.zeros_like(spike_indexes, dtype=bool)

    usable = ~np.isnan(trough_indexes)
    usable[clipped] = False

    def nan_template():
        # NaN-filled array with one slot per trough
        return np.zeros_like(trough_indexes) * np.nan

    peaks = peak_indexes[usable]
    thresholds = spike_indexes[usable]
    troughs = trough_indexes[usable].astype(int)

    heights = nan_template()
    heights[usable] = v[peaks] - v[troughs]

    width_levels = nan_template()
    width_levels[usable] = heights[usable] / 2.0 + v[troughs]

    thresh_to_peak_levels = nan_template()
    thresh_to_peak_levels[usable] = (v[peaks] - v[thresholds]) / 2.0 + v[thresholds]

    # Some spikes in burst may have deep trough but short height, so can't use same
    # definition for width
    below_threshold = width_levels < v[spike_indexes]
    width_levels[below_threshold] = thresh_to_peak_levels[below_threshold]

    levels = width_levels[usable]

    # Walk backwards from the peak towards the threshold for the rising side
    start_values = []
    for peak, threshold, level in zip(peaks, thresholds, levels):
        crossed = np.flatnonzero(v[peak:threshold:-1] <= level)
        start_values.append(peak - crossed[0] if crossed.size > 0 else np.nan)
    width_starts = nan_template()
    width_starts[usable] = np.array(start_values)

    # Walk forwards from the peak towards the trough for the falling side
    end_values = []
    for peak, trough, level in zip(peaks, troughs, levels):
        crossed = np.flatnonzero(v[peak:trough] <= level)
        end_values.append(peak + crossed[0] if crossed.size > 0 else np.nan)
    width_ends = nan_template()
    width_ends[usable] = np.array(end_values)

    found = ~(np.isnan(width_starts) | np.isnan(width_ends))
    widths = np.full(width_starts.shape, np.nan, dtype=np.float64)
    widths[found] = (
        t[width_ends[found].astype(int)] - t[width_starts[found].astype(int)]
    )
    return widths


def find_widths_wrt_threshold(
    v, t, spike_indexes, peak_indexes, trough_indexes, clipped=None
):
    """Find widths at half-height for spikes, with height measured from the
    threshold (not the trough) to the peak.

    Trough indexes are still needed to bound the search for the point after
    the peak where the voltage falls back to the half-height level. Clipped
    spikes and spikes with a NaN trough get a NaN width, as do spikes where
    the level is never reached on either side of the peak.

    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    spike_indexes : numpy array of spike indexes (they are actually the thresholds)
    peak_indexes : numpy array of spike peak indexes
    trough_indexes : numpy array of trough indexes
    clipped : boolean numpy array of clipped status of spikes (optional, defaults to all False)

    Returns
    -------
    widths : numpy array of spike widths in sec

    Raises
    ------
    FeatureError : if there are more troughs than spikes
    """
    if spike_indexes.size == 0 or peak_indexes.size == 0:
        return np.array([])

    if len(trough_indexes) > len(spike_indexes):
        raise FeatureError("Cannot have more troughs than spikes")

    if clipped is None:
        clipped = np.zeros_like(spike_indexes, dtype=bool)

    usable = ~np.isnan(trough_indexes)
    usable[clipped] = False

    def nan_template():
        # NaN-filled array with one slot per trough
        return np.zeros_like(trough_indexes) * np.nan

    peaks = peak_indexes[usable]
    thresholds = spike_indexes[usable]
    threshold_voltage = v[thresholds.astype(int)]

    heights = nan_template()
    heights[usable] = v[peaks] - threshold_voltage

    width_levels = nan_template()
    width_levels[usable] = heights[usable] / 2.0 + threshold_voltage

    levels = width_levels[usable]

    # Walk backwards from the peak towards the threshold for the rising side
    start_values = []
    for peak, threshold, level in zip(peaks, thresholds, levels):
        crossed = np.flatnonzero(v[peak:threshold:-1] <= level)
        start_values.append(peak - crossed[0] if crossed.size > 0 else np.nan)
    width_starts = nan_template()
    width_starts[usable] = np.array(start_values)

    # Walk forwards from the peak towards the trough for the falling side
    end_values = []
    for peak, trough, level in zip(
        peaks, trough_indexes[usable].astype(int), levels
    ):
        crossed = np.flatnonzero(v[peak:trough] <= level)
        end_values.append(peak + crossed[0] if crossed.size > 0 else np.nan)
    width_ends = nan_template()
    width_ends[usable] = np.array(end_values)

    found = ~(np.isnan(width_starts) | np.isnan(width_ends))
    widths = np.full(width_starts.shape, np.nan, dtype=np.float64)
    widths[found] = (
        t[width_ends[found].astype(int)] - t[width_starts[found].astype(int)]
    )
    return widths


def analyze_trough_details(
    v,
    t,
    spike_indexes,
    peak_indexes,
    clipped=None,
    end=None,
    filter=10.0,
    heavy_filter=1.0,
    term_frac=0.003,
    adp_thresh=0.5,
    tol=0.5,
    flat_interval=0.002,
    adp_max_delta_t=0.01,
    adp_max_delta_v=10.0,
    dvdt=None,
):
    """Analyze trough to determine if an ADP exists and whether the reset is a 'detour' or 'direct'

    Only spikes that are not clipped are analyzed. Spikes for which no fast
    trough can be identified are additionally marked as clipped; when a
    `clipped` array is passed in, it is updated in place.

    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    spike_indexes : numpy array of spike indexes
    peak_indexes : numpy array of spike peak indexes
    clipped : boolean numpy array of clipped status of spikes (optional, defaults to all False)
    end : end of time window (optional, defaults to t[-1])
    filter : cutoff frequency for 4-pole low-pass Bessel filter in kHz (default 10)
    heavy_filter : lower cutoff frequency for 4-pole low-pass Bessel filter in kHz (default 1)
    term_frac : fraction of the minimum (downstroke) dV/dt; the spike is considered
                terminated at the first point after the downstroke where dV/dt
                is at or above this fraction of it (default 0.003)
    adp_thresh: minimum dV/dt in V/s to exceed to be considered to have an ADP (optional, default 0.5)
    tol : tolerance for evaluating whether Vm drops appreciably further after end of spike (default 0.5 mV)
    flat_interval: if the trace is flat for this duration, stop looking for an ADP (default 0.002 s)
    adp_max_delta_t: max possible ADP delta t (default 0.01 s)
    adp_max_delta_v: max possible ADP delta v (default 10 mV)
    dvdt : pre-calculated time-derivative of voltage (optional)

    Returns
    -------
    output : list or tuple holding, in order:
        isi_types : numpy array of isi reset types (direct or detour)
        fast_trough_indexes : numpy array of indexes at the start of the trough (i.e. end of the spike)
        adp_indexes : numpy array of adp indexes (np.nan if there was no ADP in that ISI)
        slow_trough_indexes : numpy array of indexes at the minimum of the slow phase of the trough
                              (if there wasn't just a fast phase)
    clipped : boolean numpy array of updated clipped status of spikes
    """

    if end is None:
        end = t[-1]
    end_index = find_time_index(t, end)

    if clipped is None:
        clipped = np.zeros_like(peak_indexes, dtype=bool)

    # Can't evaluate for spikes that are clipped by the window
    orig_len = len(peak_indexes)
    valid_spike_indexes = spike_indexes[~clipped]
    valid_peak_indexes = peak_indexes[~clipped]

    if dvdt is None:
        dvdt = calculate_dvdt(v, t, filter)

    dvdt_hvy = calculate_dvdt(v, t, heavy_filter)

    # Writing as for loop - see if I can vectorize any later
    fast_trough_indexes = []
    adp_indexes = []
    slow_trough_indexes = []
    isi_types = []

    update_clipped = []
    # Each spike is examined from its peak up to the next valid spike
    # (or the end of the window for the last one)
    for peak, next_spk in zip(
        valid_peak_indexes, np.append(valid_spike_indexes[1:], end_index)
    ):
        downstroke = dvdt[peak:next_spk].argmin() + peak
        target = term_frac * dvdt[downstroke]

        terminated_points = np.flatnonzero(dvdt[downstroke:next_spk] >= target)
        if terminated_points.size:
            terminated = terminated_points[0] + downstroke
            update_clipped.append(False)
        else:
            logging.debug("Could not identify fast trough - marking spike as clipped")
            isi_types.append(np.nan)
            fast_trough_indexes.append(np.nan)
            adp_indexes.append(np.nan)
            slow_trough_indexes.append(np.nan)
            update_clipped.append(True)
            continue

        # Could there be an ADP?
        adp_index = np.nan
        dv_over_thresh = np.flatnonzero(dvdt_hvy[terminated:next_spk] >= adp_thresh)
        if dv_over_thresh.size:
            cross = dv_over_thresh[0] + terminated

            # only want to look for ADP before things get pretty flat
            # otherwise, could just pick up random transients long after the spike
            if t[cross] - t[terminated] < flat_interval:
                # Going back up fast, but could just be going into another spike
                # so need to check for a reversal (zero-crossing) in dV/dt
                zero_return_vals = np.flatnonzero(dvdt_hvy[cross:next_spk] <= 0)
                if zero_return_vals.size:
                    putative_adp_index = zero_return_vals[0] + cross
                    min_index = (
                        v[putative_adp_index:next_spk].argmin() + putative_adp_index
                    )
                    if (
                        v[putative_adp_index] - v[min_index] >= tol
                        and v[putative_adp_index] - v[terminated] <= adp_max_delta_v
                        and t[putative_adp_index] - t[terminated] <= adp_max_delta_t
                    ):
                        adp_index = putative_adp_index
                        slow_phase_min_index = min_index
                        isi_type = "detour"

        if np.isnan(adp_index):
            v_term = v[terminated]
            min_index = v[terminated:next_spk].argmin() + terminated
            if v_term - v[min_index] >= tol:
                # dropped further after end of spike -> detour reset
                isi_type = "detour"
                slow_phase_min_index = min_index
            else:
                isi_type = "direct"

        isi_types.append(isi_type)
        fast_trough_indexes.append(terminated)
        adp_indexes.append(adp_index)
        if isi_type == "detour":
            slow_trough_indexes.append(slow_phase_min_index)
        else:
            slow_trough_indexes.append(np.nan)

    # If we had to kick some spikes out before, need to add nans at the end
    output = []
    output.append(np.array(isi_types))
    for d in (fast_trough_indexes, adp_indexes, slow_trough_indexes):
        output.append(np.array(d, dtype=float))

    if orig_len > len(isi_types):
        extra = np.zeros(orig_len - len(isi_types)) * np.nan
        output = tuple((np.append(o, extra) for o in output))

    # The ADP and slow trough for the last spike in a train are not reliably
    # calculated, and usually extreme when wrong, so we will NaN them out.
    #
    # Note that this will result in a 0 value when delta V or delta T is
    # calculated, which may not be strictly accurate to the trace, but the
    # magnitude of the difference will be less than in many of the erroneous
    # cases seen otherwise
    #
    # Skipped when there are no spikes at all: indexing [-1] on an empty
    # array would raise IndexError.
    if output[2].size:
        output[2][-1] = np.nan  # ADP
        output[3][-1] = np.nan  # slow trough

    clipped[~clipped] = update_clipped
    return output, clipped


def find_time_index(t, t_0):
    """Return the index of the first sample of `t` that is at or after `t_0`.

    Parameters
    ----------
    t : numpy array of times
    t_0 : time to look up

    Returns
    -------
    index of the first element of `t` that is greater than or equal to `t_0`

    Raises
    ------
    FeatureError
        If no element of `t` is greater than or equal to `t_0`.
    """

    reached = t >= t_0
    if not np.any(reached):
        raise FeatureError("Could not find given time in time vector")

    # argmax over a boolean mask gives the position of its first True entry
    return np.argmax(reached)


def calculate_dvdt(v, t, filter=None):
    """Low-pass filters (if requested) and differentiates voltage by time.

    Filtering is only applied when `filter` is truthy AND `t` is uniformly
    sampled (see `has_fixed_dt`); otherwise the raw voltage is differentiated.

    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    filter : cutoff frequency for 4-pole low-pass Bessel filter in kHz (optional, default None)

    Returns
    -------
    dvdt : numpy array of time-derivative of voltage (V/s = mV/ms)
        NaN entries are dropped, so the result can be shorter than `len(v) - 1`.

    Raises
    ------
    ValueError
        If the requested cutoff is not below the Nyquist frequency of `t`.
    """

    # Check the cheap `filter` flag first so that the O(n) uniform-sampling test
    # only runs when filtering was actually requested (and so that unfiltered
    # calls do not fail on inputs too short to have a sampling interval).
    if filter and has_fixed_dt(t):
        delta_t = t[1] - t[0]
        sample_freq = 1.0 / delta_t
        # filter kHz -> Hz, then get fraction of Nyquist frequency
        filt_coeff = (filter * 1e3) / (sample_freq / 2.0)
        if filt_coeff < 0 or filt_coeff >= 1:
            raise ValueError(
                "bessel coeff ({:f}) is outside of valid range [0,1); cannot filter sampling frequency {:.1f} kHz with cutoff frequency {:.1f} kHz.".format(
                    filt_coeff, sample_freq / 1e3, filter
                )
            )
        b, a = signal.bessel(4, filt_coeff, "low")
        # filtfilt runs the filter forward and backward (zero phase shift)
        v_filt = signal.filtfilt(b, a, v, axis=0)
        dv = np.diff(v_filt)
    else:
        dv = np.diff(v)

    dt = np.diff(t)
    dvdt = 1e-3 * dv / dt  # in V/s = mV/ms

    # Remove nan values (produced by 0/0 when both dv and dt are zero)
    dvdt = dvdt[~np.isnan(dvdt)]

    return dvdt


def get_isis(t, spikes):
    """Return the interspike intervals (in sec) for spikes given as indexes into `t`.

    An empty array is returned when there are fewer than two spikes.
    """

    if len(spikes) > 1:
        later_times = t[spikes[1:]]
        earlier_times = t[spikes[:-1]]
        return later_times - earlier_times

    return np.array([])


def average_voltage(v, t, start=None, end=None):
    """Calculate average voltage between start and end.

    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    start : start of the averaging window (optional, default None -> first time point)
    end : end of the averaging window (optional, default None -> last time point)

    Returns
    -------
    v_avg : average voltage
        The sample at the `end` index itself is not included in the average.
    """

    window_start = t[0] if start is None else start
    window_end = t[-1] if end is None else end

    first = find_time_index(t, window_start)
    last = find_time_index(t, window_end)

    return v[first:last].mean()


def adaptation_index(isis):
    """Return the adaptation index of `isis` (`norm_diff(isis)`), or NaN if `isis` is empty."""
    return norm_diff(isis) if len(isis) else np.nan


def latency(t, spikes, start):
    """Return the time from `start` to the first spike.

    Parameters
    ----------
    t : numpy array of times in seconds
    spikes : numpy array of spike indexes
    start : reference time; the first time point of `t` is used when None

    Returns
    -------
    time of the first spike minus `start`, or NaN when there are no spikes
    """

    if not len(spikes):
        return np.nan

    reference = t[0] if start is None else start
    return t[spikes[0]] - reference


def average_rate(t, spikes, start, end):
    """Calculate average firing rate during interval between `start` and `end`.

    Parameters
    ----------
    t : numpy array of times in seconds
    spikes : numpy array of spike indexes
    start : start of the interval (first time point of `t` is used when None)
    end : end of the interval (last time point of `t` is used when None)

    Returns
    -------
    avg_rate : average firing rate in spikes/sec
        Spikes exactly at `start` or `end` are counted.
    """

    window_start = t[0] if start is None else start
    window_end = t[-1] if end is None else end

    # Count spike by spike so that an empty spike container of any dtype works
    n_spikes = sum(1 for spk in spikes if window_start <= t[spk] <= window_end)
    return n_spikes / (window_end - window_start)


def norm_diff(a):
    """Calculate average of (a[i+1] - a[i]) / (a[i+1] + a[i]).

    Parameters
    ----------
    a : array-like of values (converted to a float array)

    Returns
    -------
    NaN if `a` has fewer than two entries; 0.0 if every adjacent pair sums to
    (approximately) zero; otherwise the NaN-ignoring mean of the normalized
    differences. Pairs in which both values are exactly zero contribute 0.
    """

    if len(a) <= 1:
        return np.nan

    a = np.asarray(a, dtype=float)
    later, earlier = a[1:], a[:-1]
    pair_sums = later + earlier
    if np.allclose(pair_sums, 0.0):
        return 0.0

    # Pairs of zeros evaluate to 0/0 here; they are overwritten with 0.0 right
    # below, so the "invalid value" warning for them would only be noise.
    with np.errstate(invalid="ignore"):
        norm_diffs = (later - earlier) / pair_sums
    norm_diffs[(later == 0) & (earlier == 0)] = 0.0

    with warnings.catch_warnings():
        # nanmean warns when every entry is NaN; the NaN result is acceptable here
        warnings.filterwarnings("ignore", category=RuntimeWarning, module="numpy")
        avg = np.nanmean(norm_diffs)
    return avg


def norm_sq_diff(a):
    """Calculate average of (a[i+1] - a[i])^2 / (a[i] + a[i+1])^2.

    Returns NaN when `a` has fewer than two entries.
    """
    if len(a) < 2:
        return np.nan

    values = a.astype(float)
    successors, predecessors = values[1:], values[:-1]
    squared_diffs = np.square(successors - predecessors)
    squared_sums = np.square(successors + predecessors)
    return (squared_diffs / squared_sums).mean()


def isi_adaptation(a):
    """Return the mean ratio of each entry of `a` to the entry before it (a[1:] / a[:-1]).

    Returns NaN when `a` has fewer than two entries.
    """
    if len(a) < 2:
        return np.nan

    values = a.astype(float)
    return np.mean(values[1:] / values[:-1])


def ap_amp_adaptation(a):
    """Return the mean ratio of each entry of `a` to the entry before it (a[1:] / a[:-1]).

    Returns NaN when `a` has fewer than two entries.
    """
    if len(a) < 2:
        return np.nan

    amplitudes = a.astype(float)
    successive_ratios = amplitudes[1:] / amplitudes[:-1]
    return np.mean(successive_ratios)


def has_fixed_dt(t):
    """Return True when every interval of `t` matches the first one (within np.allclose tolerance)."""
    intervals = np.diff(t)
    # Comparing against the scalar first interval broadcasts over all intervals
    return np.allclose(intervals, intervals[0])


def fit_membrane_time_constant(v, t, start, end, min_rsme=1e-4):
    """Fit an exponential to estimate membrane time constant between start and end

    Parameters
    ----------
    v : numpy array of voltages in mV
    t : numpy array of times in seconds
    start : start of time window for exponential fit
    end : end of time window for exponential fit
    min_rsme: largest accepted value of the fit-error statistic (default 1e-4);
        fits whose statistic exceeds it are rejected

    Returns
    -------
    a, inv_tau, y0 : Coefficients of equation y0 + a * exp(-inv_tau * x)

    returns np.nan for values if the window is empty, the fit fails, or the
    fit is rejected
    """

    failed_fit = (np.nan, np.nan, np.nan)

    first = find_time_index(t, start)
    last = find_time_index(t, end)

    # Initial guess: amplitude = voltage change over the window, rate = 50 / s,
    # offset = voltage at the end of the window
    v_first, v_last = v[first], v[last]
    initial_guess = (v_first - v_last, 50.0, v_last)

    fit_t = (t[first:last] - t[first]).astype(np.float64)
    fit_v = v[first:last].astype(np.float64)

    if fit_v.size == 0:
        logging.info("Voltage window had no size")
        return failed_fit

    try:
        popt, pcov = curve_fit(_exp_curve, fit_t, fit_v, p0=initial_guess)
    except RuntimeError:
        logging.info("Curve fit for membrane time constant failed")
        return failed_fit

    residuals = _exp_curve(fit_t, *popt) - fit_v
    # NOTE(review): this statistic is sqrt(|mean residual|), not a true
    # root-mean-square error (the residuals are not squared). It is kept as is
    # because `min_rsme` is compared against this quantity - confirm intent
    # before changing it.
    rsme = np.sqrt(np.abs(np.mean(residuals)))
    if rsme > min_rsme:
        logging.debug("Curve fit for membrane time constant did not meet RSME standard")
        return failed_fit

    return popt


def fit_membrane_time_constant_at_end(v, t, start, end, min_rsme=1e-4):
    """Fit a saturating exponential to estimate membrane time constant between start and end

    Parameters
    ----------
    v : numpy array of voltages in mV
    t : numpy array of times in seconds
    start : start of time window for exponential fit
    end : end of time window for exponential fit
    min_rsme: largest accepted value of the fit-error statistic (default 1e-4);
        fits whose statistic exceeds it are rejected

    Returns
    -------
    b, inv_tau, A : Coefficients of equation A + b * (1 - exp(-inv_tau * x))

    returns np.nan for values if the window is empty, the fit fails, or the
    fit is rejected
    """

    start_index = find_time_index(t, start)
    end_index = find_time_index(t, end)

    # Initial guess: amplitude = voltage change over the window, rate = 50 / s,
    # offset = voltage at the start of the window
    guess = (v[end_index] - v[start_index], 50.0, v[start_index])
    t_window = (t[start_index:end_index] - t[start_index]).astype(np.float64)
    v_window = v[start_index:end_index].astype(np.float64)

    # curve_fit cannot work with an empty window; report a failed fit instead,
    # matching fit_membrane_time_constant
    if v_window.size == 0:
        logging.info("Voltage window had no size")
        return np.nan, np.nan, np.nan

    try:
        popt, pcov = curve_fit(_exp_curve_at_end, t_window, v_window, p0=guess)
    except RuntimeError:
        logging.info("Curve fit for membrane time constant failed")
        return np.nan, np.nan, np.nan

    # Evaluate the SAME model that was fitted; popt is ordered (b, inv_tau, A)
    # and must not be fed to _exp_curve, whose parameters mean something else.
    pred = _exp_curve_at_end(t_window, *popt)
    # np.abs keeps the statistic real when the mean residual is negative;
    # without it sqrt yields NaN, and `NaN > min_rsme` silently accepts the fit.
    # NOTE(review): as in fit_membrane_time_constant, this is sqrt(|mean
    # residual|) rather than a true RMSE - confirm intent before changing it.
    rsme = np.sqrt(np.abs(np.mean(pred - v_window)))
    if rsme > min_rsme:
        logging.debug("Curve fit for membrane time constant did not meet RSME standard")
        return np.nan, np.nan, np.nan

    return popt


def detect_pauses(isis, isi_types, cost_weight=1.0):
    """Determine which ISIs are "pauses" in ongoing firing.

    Pauses are unusually long ISIs with a "detour reset" among "direct resets".

    Candidates are every "detour" ISI followed by every "direct" ISI longer
    than three times the median "direct" ISI. They are added one at a time, in
    that order, and each addition is scored as benefit minus cost (see
    `cost_weight`).

    Parameters
    ----------
    isis : numpy array of interspike intervals
    isi_types : numpy array of interspike interval types ('direct' or 'detour')
    cost_weight : weight for cost function for calling an ISI a pause
        Higher cost weights lead to fewer ISIs identified as pauses. The benefit is
        the drop in coefficient of variation once the "pause" ISIs are removed; the
        cost depends on the difference between the duration of the "pause" ISIs and
        the average duration and standard deviation of "non-pause" ISIs.

    Returns
    -------
    pauses : numpy array of indices corresponding to pauses in `isis`

    Raises
    ------
    FeatureError
        If `isis` and `isi_types` differ in length.
    """

    if len(isis) != len(isi_types):
        raise FeatureError("Wrong number of ISIs")

    is_direct = isi_types == "direct"
    if not np.any(is_direct):
        # Need some direct-type firing to have pauses
        return np.array([])

    long_direct_cutoff = 3 * np.median(isis[is_direct])
    detour_candidates = np.flatnonzero(isi_types == "detour")
    direct_candidates = np.flatnonzero(is_direct & (isis > long_direct_cutoff))
    # Detours are tried first, then the unusually long direct ISIs
    candidates = np.concatenate((detour_candidates, direct_candidates))

    if candidates.size == 0:
        return np.array([])

    overall_cv = isis.std() / isis.mean()
    accepted = np.array([], dtype=int)
    best_net = 0
    for candidate in candidates:
        trial = np.append(accepted, candidate)
        remaining = np.delete(isis, trial)
        paused = isis[trial]
        if len(remaining) < 2:
            break

        remaining_mean = remaining.mean()
        remaining_std = remaining.std()
        benefit = overall_cv - remaining_std / remaining_mean

        gaps = np.abs(remaining_mean - paused)
        if np.any(gaps == 0):
            # a "pause" equal to the non-pause mean would divide by zero
            cost = 1000  # just a huge cost
        else:
            cost = np.sum(remaining_std / gaps)
        cost *= cost_weight

        net = benefit - cost
        if 0 < net < best_net:
            # still positive but worse than the best so far: stop adding
            break
        if net > best_net:
            best_net = net
        accepted = trial

    if best_net <= 0:
        accepted = np.array([])

    return np.sort(accepted)


def detect_bursts(
    isis,
    isi_types,
    fast_tr_v,
    fast_tr_t,
    slow_tr_v,
    slow_tr_t,
    thr_v,
    tol=0.5,
    pause_cost=1.0,
):
    """Detect bursts in spike train.

    Parameters
    ----------
    isis : numpy array of n interspike intervals
    isi_types : numpy array of n interspike interval types
    fast_tr_v : numpy array of fast trough voltages for the n + 1 spikes of the train
        (currently unused)
    fast_tr_t : numpy array of fast trough times for the n + 1 spikes of the train
    slow_tr_v : numpy array of slow trough voltages for the n + 1 spikes of the train
        (currently unused)
    slow_tr_t : numpy array of slow trough times for the n + 1 spikes of the train
    thr_v : numpy array of threshold voltages for the n + 1 spikes of the train
        (currently unused)
    tol : tolerance for the difference in slow trough voltages and thresholds (default 0.5 mV)
        Intended to identify "delay" interspike intervals that occur within a burst;
        the step that used it is disabled, so it currently has no effect
    pause_cost : cost weight passed to `detect_pauses` (default 1.0)

    Returns
    -------
    bursts : list of bursts, or an empty numpy array when no bursts are found
        Each item in list is a tuple of the form (burst_index, start, end) where `burst_index`
        is a comparison index between the highest instantaneous rate within the burst vs
        the highest instantaneous rate outside the burst. `start` is the index of the first
        ISI of the burst, and `end` is the ISI index immediately following the burst.

    Raises
    ------
    FeatureError
        If `isis` and `isi_types` differ in length, or if the numbers of burst
        starts and burst ends do not match.
    """

    if len(isis) != len(isi_types):
        raise FeatureError("Wrong number of ISIs")

    if len(isis) < 2:  # can't determine burstiness for a single ISI
        return np.array([])

    # Trough times cover n + 1 spikes; drop the last to line up with the n ISIs
    fast_tr_t = fast_tr_t[:-1]
    slow_tr_t = slow_tr_t[:-1]

    # Work on a copy so the caller's isi types are not changed. The copy uses
    # object dtype because a fixed-width string array built from
    # "direct"/"detour" holds only 6 characters: assigning "pauselike" into it
    # would be silently truncated and never match the comparisons below.
    isi_types = np.array(isi_types, dtype=object)

    # Burst transitions can't be at "pause"-like ISIs
    pauses = detect_pauses(isis, isi_types, cost_weight=pause_cost).astype(int)
    isi_types[pauses] = "pauselike"

    if not (np.any(isi_types == "direct") and np.any(isi_types == "detour")):
        # no candidates that could be bursts
        return np.array([])

    # Find transitions from detour -> direct and vice versa for burst boundaries
    into_burst = np.array(
        [
            i + 1
            for i, (prev, cur) in enumerate(zip(isi_types[:-1], isi_types[1:]))
            if cur == "direct" and prev == "detour"
        ],
        dtype=int,
    )
    if isi_types[0] == "direct":
        # the train starts inside a burst
        into_burst = np.append(np.array([0]), into_burst)

    drop_into = []
    out_of_burst = []
    # Each burst start is searched up to the next burst start (or the end of the train)
    segment_ends = np.append(into_burst[1:], len(isis))
    for j, (into, segment_end) in enumerate(zip(into_burst, segment_ends)):
        for i, isi in enumerate(isi_types[into + 1 : segment_end]):
            if isi == "detour":
                out_of_burst.append(i + into + 1)
                break
            elif isi == "pauselike":
                drop_into.append(j)
                break

    # Burst transitions can't be at pauselike transitions ==> delete here
    mask = np.ones_like(into_burst, dtype=bool)
    mask[drop_into] = False
    into_burst = into_burst[mask]

    out_of_burst = np.array(out_of_burst)
    if len(out_of_burst) == len(into_burst) - 1:
        # the last burst runs to the end of the train
        out_of_burst = np.append(out_of_burst, len(isi_types))

    if not (into_burst.size or out_of_burst.size):
        return np.array([])

    if len(into_burst) != len(out_of_burst):
        raise FeatureError("Inconsistent burst boundary identification")

    inout_pairs = list(zip(into_burst, out_of_burst))
    delta_t = slow_tr_t - fast_tr_t

    # Greedily drop the worst-scoring burst while that improves the mean score
    scores = _score_burst_set(inout_pairs, isis, delta_t)
    best_score = np.mean(scores)
    worst = np.argmin(scores)
    # Trial removals must happen on a copy: deleting from an alias would also
    # remove the burst from `inout_pairs` before the smaller set has proven better.
    test_bursts = list(inout_pairs)
    del test_bursts[worst]
    while len(test_bursts) > 0:
        scores = _score_burst_set(test_bursts, isis, delta_t)
        if np.mean(scores) > best_score:
            best_score = np.mean(scores)
            inout_pairs = list(test_bursts)
            worst = np.argmin(scores)
            del test_bursts[worst]
        else:
            break

    if best_score < 0:
        return np.array([])

    bursts = []
    for i, (into, outof) in enumerate(inout_pairs):
        if i == len(inout_pairs) - 1:  # last burst to evaluate
            if outof <= len(isis) - 1:  # are there spikes left after the burst?
                metric = _burstiness_index(isis[into:outof], isis[outof:])
            elif i == 0:  # was this the first one (and there weren't spikes after)?
                metric = _burstiness_index(isis[into:outof], isis[:into])
            else:
                prev_burst = inout_pairs[i - 1]
                metric = _burstiness_index(isis[into:outof], isis[prev_burst[1] : into])
        else:
            # Normal case: compare the burst with the ISIs between its end and
            # the start of the next burst
            next_burst = inout_pairs[i + 1]
            metric = _burstiness_index(isis[into:outof], isis[outof : next_burst[0]])
        bursts.append((metric, into, outof))

    return bursts


def fit_prespike_time_constant(v, t, start, spike_time, dv_limit=-0.001, tau_limit=0.3):
    """Finds the dominant time constant of the pre-spike rise in voltage

    A single exponential is first fitted linearly (on the log-transformed trace)
    to seed a double-exponential fit; if that fit fails, the single-exponential
    time constant is returned.

    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    start : start of voltage rise (seconds)
    spike_time : time of first spike (seconds)
    dv_limit : cutoff for the time-derivative of the log-transformed trace (default -0.001)
        Shortens the window of the initial linear fit at the first point where
        the derivative falls to this limit or below
    tau_limit : upper bound for slow time constant (seconds, default 0.3)
        If the slower time constant exceeds this limit and its weight is less than
        twice that of the faster one, the faster one is considered the dominant one

    Returns
    -------
    tau : dominant time constant (seconds)

    Raises
    ------
    FeatureError
        If the spike time does not come after `start`.
    """

    first = find_time_index(t, start)
    last = find_time_index(t, spike_time)
    if last <= first:
        raise FeatureError(
            "Start for pre-spike time constant fit cannot be after the spike time."
        )

    v_slice = v[first:last]
    t_slice = t[first:last]
    t_rel = t_slice - t_slice[0]

    # Solve linear version with single exponential first to guess at the time constant
    y0 = v_slice.max() + 5e-6  # set y0 slightly above v_slice maximum
    y = np.log(-v_slice + y0)

    dy = calculate_dvdt(y, t_slice, filter=1.0)

    # End the fit interval if the voltage starts dropping
    crossings = np.flatnonzero(dy <= dv_limit)
    cross_limit = 0.0005  # sec
    if crossings.size and t_slice[crossings[0]] - t_slice[0] >= cross_limit:
        fit_stop = crossings[0]
    else:
        # either never crosses or crosses too early
        fit_stop = len(v_slice)

    K, A_log = np.polyfit(t_rel[:fit_stop], y[:fit_stop], 1)
    A = np.exp(A_log)
    single_tau = -1.0 / K

    def dbl_exp_y0(x, A1, tau1, A2, tau2):
        # double exponential with the offset pinned to y0
        return _dbl_exp_fit(y0, x, A1, tau1, A2, tau2)

    try:
        popt, pcov = curve_fit(
            dbl_exp_y0,
            t_rel,
            v_slice,
            p0=(-A / 2.0, single_tau, -A / 2.0, single_tau),
        )
    except RuntimeError:
        # Fall back to single fit
        return single_tau

    # Sort the two fitted components into faster / slower
    if popt[1] < popt[3]:
        faster_weight, faster_tau, slower_weight, slower_tau = popt
    else:
        slower_weight, slower_tau, faster_weight, faster_tau = popt

    # Find dominant time constant; these are all empirical values
    if np.abs(faster_weight) > np.abs(slower_weight):
        return faster_tau
    if (slower_tau - faster_tau) / slower_tau <= 0.1:  # close enough; just use slower
        return slower_tau
    if slower_tau > tau_limit and slower_weight / faster_weight < 2.0:
        return faster_tau
    return slower_tau


def estimate_adjusted_detection_parameters(
    v_set, t_set, interval_start, interval_end, filter=10
):
    """
    Estimate adjusted values for spike detection by analyzing a period when the voltage
    changes quickly but passively (due to strong current stimulation), which can result
    in spurious spike detection results.

    Parameters
    ----------
    v_set : list of numpy arrays of voltage time series in mV
    t_set : list of numpy arrays of times in seconds
    interval_start : start of analysis interval (sec)
    interval_end : end of analysis interval (sec)
    filter : filter setting handed to `calculate_dvdt` and the spike detection
        helpers (default 10)

    Returns
    -------
    new_dv_cutoff : adjusted dv/dt cutoff (V/s)
    new_thresh_frac : adjusted fraction of avg upstroke to find threshold

    Raises
    ------
    FeatureError
        If `v_set` and `t_set` differ in length or are empty.
    """

    if type(v_set) is not list:
        v_set = list(v_set)

    if type(t_set) is not list:
        t_set = list(t_set)

    if len(v_set) != len(t_set):
        raise FeatureError("t_set and v_set must be lists of equal size")

    if len(v_set) == 0:
        raise FeatureError("t_set and v_set are empty")

    # The interval indexes are taken from the first sweep and reused for all sweeps
    first = find_time_index(t_set[0], interval_start)
    last = find_time_index(t_set[0], interval_end)

    dv_set = []
    interval_maxes = []
    interval_end_values = []
    for sweep_v, sweep_t in zip(v_set, t_set):
        sweep_dv = calculate_dvdt(sweep_v, sweep_t, filter)
        dv_set.append(sweep_dv)
        interval_maxes.append(sweep_dv[first:last].max())
        interval_end_values.append(sweep_dv[last])

    cutoff_adj_factor = 1.1
    thresh_frac_adj_factor = 1.2

    new_dv_cutoff = np.median(np.array(interval_maxes)) * cutoff_adj_factor
    min_thresh = np.median(np.array(interval_end_values)) * thresh_frac_adj_factor

    # Re-run spike detection with the adjusted cutoff and pool the upstroke dv/dt values
    all_upstrokes = np.array([])
    for sweep_v, sweep_t, sweep_dv in zip(v_set, t_set, dv_set):
        putative_spikes = detect_putative_spikes(
            sweep_v,
            sweep_t,
            start=interval_start,
            end=interval_end,
            dv_cutoff=new_dv_cutoff,
            filter=filter,
        )
        peaks = find_peak_indexes(sweep_v, sweep_t, putative_spikes)
        putative_spikes, peaks = filter_putative_spikes(
            sweep_v, sweep_t, putative_spikes, peaks, dvdt=sweep_dv, filter=filter
        )
        upstrokes = find_upstroke_indexes(
            sweep_v, sweep_t, putative_spikes, peaks, dvdt=sweep_dv
        )
        if upstrokes.size:
            all_upstrokes = np.append(all_upstrokes, sweep_dv[upstrokes])

    new_thresh_frac = min_thresh / all_upstrokes.mean()

    return new_dv_cutoff, new_thresh_frac


def _score_burst_set(bursts, isis, delta_t, c_n=0.1, c_tx=0.01):
    in_burst = np.zeros_like(isis, dtype=bool)
    for b in bursts:
        in_burst[int(b[0]) : int(b[1])] = True

    # If all ISIs are part of a burst, give it a bad score
    if len(isis[~in_burst]) == 0:
        return [-1e12] * len(bursts)

    delta_frac = delta_t / isis

    scores = []
    for b in bursts:
        score = _burstiness_index(
            isis[int(b[0]) : int(b[1])], isis[~in_burst]
        )  # base score
        if b[1] < len(delta_t):
            score -= c_tx * (1.0 / (delta_frac[int(b[1])]))  # cost for ending a burst
        if b[0] > 0:
            score -= c_tx * (
                1.0 / delta_frac[int(b[0]) - 1]
            )  # cost for starting a burst
        score -= c_n * (b[1] - b[0] - 1)  # cost for extending a burst
        scores.append(score)

    return scores


def _burstiness_index(in_burst_isis, out_burst_isis):
    burst_rate = 1.0 / in_burst_isis.min()
    out_rate = 1.0 / out_burst_isis.min()
    return (burst_rate - out_rate) / (burst_rate + out_rate)


def _exp_curve(x, a, inv_tau, y0):
    return y0 + a * np.exp(-inv_tau * x)


def _exp_curve_at_end(x, b, inv_tau, A):
    return A + b * (1 - np.exp(-inv_tau * x))


def _dbl_exp_fit(y0, x, A1, tau1, A2, tau2):
    penalty = 0
    if tau1 < 0 or tau2 < 0:
        penalty = 1e6
    return y0 + A1 * np.exp(-x / tau1) + A2 * np.exp(-x / tau2) + penalty


class FeatureError(Exception):
    """Exception raised by the feature detection functions when they cannot proceed."""
