# Allen Institute Software License - This software license is the 2-clause BSD
# license plus a third clause that prohibits redistribution for commercial
# purposes without further permission.
#
# Copyright 2015-2016. Allen Institute. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Redistributions for commercial purposes are not permitted without the
# Allen Institute's written permission.
# For purposes of this license, commercial purposes is the incorporation of the
# Allen Institute's software into anything for which you will charge fees or
# other compensation. Contact terms@alleninstitute.org for commercial licensing
# opportunities.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
import numpy as np
from pandas import DataFrame
import warnings
import logging
from collections import Counter

import wasser.ephys_features as ft
import six

# Constants for stimulus-specific analysis
RAMPS_START = 1.02
LONG_SQUARES_START = 1.02
LONG_SQUARES_END = 2.02
SHORT_SQUARES_WINDOW_START = 1.02
SHORT_SQUARES_WINDOW_END = 1.021
SHORT_SQUARE_TRIPLE_WINDOW_START = 2.02
SHORT_SQUARE_TRIPLE_WINDOW_END = 2.021


class EphysSweepFeatureExtractor:
    """Feature calculation for a sweep (voltage and/or current time series)."""

    def __init__(
        self,
        t=None,
        v=None,
        i=None,
        start=None,
        end=None,
        filter=10.0,
        dv_cutoff=20.0,
        max_interval=0.005,
        min_height=2.0,
        min_peak=-30.0,
        thresh_frac=0.05,
        baseline_interval=0.1,
        baseline_detect_thresh=0.3,
        id=None,
    ):
        """Initialize SweepFeatures object.

        Parameters
        ----------
        t : ndarray of times (seconds)
        v : ndarray of voltages (mV)
        i : ndarray of currents (pA)
        start : start of time window for feature analysis (optional)
        end : end of time window for feature analysis (optional)
        filter : cutoff frequency for 4-pole low-pass Bessel filter in kHz (optional, default 10)
        dv_cutoff : minimum dV/dt to qualify as a spike in V/s (optional, default 20)
        max_interval : maximum acceptable time between start of spike and time of peak in sec (optional, default 0.005)
        min_height : minimum acceptable height from threshold to peak in mV (optional, default 2)
        min_peak : minimum acceptable absolute peak level in mV (optional, default -30)
        thresh_frac : fraction of average upstroke for threshold calculation (optional, default 0.05)
        baseline_interval: interval length for baseline voltage calculation (before start if start is defined, default 0.1)
        baseline_detect_thresh : dV/dt threshold for evaluating flatness of baseline region (optional, default 0.3)
        """
        self.id = id
        self.t = t
        self.v = v
        self.i = i
        self.start = start
        self.end = end
        self.filter = filter
        self.dv_cutoff = dv_cutoff
        self.max_interval = max_interval
        self.min_height = min_height
        self.min_peak = min_peak
        self.thresh_frac = thresh_frac
        self.baseline_interval = baseline_interval
        self.baseline_detect_thresh = baseline_detect_thresh
        self.stimulus_amplitude_calculator = None

        self._sweep_features = {}
        self._affected_by_clipping = []

    def process_spikes(self):
        """Perform spike-related feature analysis"""
        self._process_individual_spikes()
        self._process_spike_related_features()

    def _process_individual_spikes(self):
        v = self.v
        t = self.t
        dvdt = ft.calculate_dvdt(v, t, self.filter)

        # Basic features of spikes
        try:
            putative_spikes = ft.detect_putative_spikes(
                v, t, self.start, self.end, self.filter, self.dv_cutoff
            )
        except (
            IndexError
        ):  # This just means the simulation entirely failed and there's nothing to check
            self._spikes_df = DataFrame()
            return
        # print('put spikes: ', t[putative_spikes][:3])
        peaks = ft.find_peak_indexes(v, t, putative_spikes, self.end)
        putative_spikes, peaks = ft.filter_putative_spikes(
            v,
            t,
            putative_spikes,
            peaks,
            self.min_height,
            self.min_peak,
            dvdt=dvdt,
            filter=self.filter,
        )
        # print('put spikes after filter: ', t[putative_spikes][:3])
        # print('\w peaks: ', t[peaks][:3]),
        if not putative_spikes.size:
            # Save time if no spikes detected
            self._spikes_df = DataFrame()
            return

        upstrokes = ft.find_upstroke_indexes(
            v, t, putative_spikes, peaks, self.filter, dvdt
        )

        thresholds = ft.refine_threshold_indexes(
            v, t, upstrokes, filter=self.filter, dvdt=dvdt
        )

        if t[thresholds][0] < self.start:
            # print('The threshold was put before start of stimulation @ ', t[thresholds][0])
            thresholds_ = ft.refine_threshold_indexes_based_on_third_derivative(
                v, t, peaks, upstrokes, filter=10.0, dvdt=dvdt
            )
            # print('Calculated a new threshold @ ', t[thresholds_][0])
            if t[thresholds_][0] > t[thresholds][0]:
                thresholds = thresholds_

        if t[thresholds][0] > t[peaks][0]:
            # Save time if, currently the next peak index comes before the threshold index
            self._spikes_df = DataFrame()
            return

        thresholds, peaks, upstrokes, clipped = ft.check_thresholds_and_peaks(
            v,
            t,
            thresholds,
            peaks,
            upstrokes,
            end=self.end,
            max_interval=self.max_interval,
            filter=self.filter,
        )
        if not thresholds.size:
            # Save time if no spikes detected
            self._spikes_df = DataFrame()
            return

        if t[thresholds][0] > t[peaks][0]:
            # Save time if, currently the next peak index comes before the threshold index
            self._spikes_df = DataFrame()
            return

        # Spike list and thresholds have been refined - now find other features
        while True:
            try:
                upstrokes = ft.find_upstroke_indexes(
                    v, t, thresholds, peaks, self.filter, dvdt
                )
                break
            except (
                ValueError
            ):  # Peaks and threshold indexes didn't make sense and you'd try to find a nonsense upstroke
                self._spikes_df = DataFrame()
                return

        troughs = ft.find_trough_indexes(v, t, thresholds, peaks, clipped, self.end)

        # You can comment the following lines if you don't want to check this
        thresholds, upstrokes, peaks, troughs, clipped = ft.check_trough_w_peak(
            thresholds, upstrokes, peaks, troughs, clipped, filter=10.0, dvdt=None
        )

        # Maybe you have nothing anymore so save time by the following:
        if not thresholds.size:
            self._spikes_df = DataFrame()
            return

        downstrokes = ft.find_downstroke_indexes(
            v, t, peaks, troughs, clipped, self.filter, dvdt
        )
        trough_details, clipped = ft.analyze_trough_details(
            v, t, thresholds, peaks, clipped, self.end, self.filter, dvdt=dvdt
        )
        widths = ft.find_widths_wrt_threshold(
            v, t, thresholds, peaks, trough_details[1], clipped
        )

        base_clipped_list = []

        # Points where we care about t, v, and i if available
        vit_data_indexes = {
            "threshold": thresholds,
            "peak": peaks,
            "trough": troughs,
        }
        base_clipped_list += ["trough"]

        # Points where we care about t and dv/dt
        dvdt_data_indexes = {"upstroke": upstrokes, "downstroke": downstrokes}
        base_clipped_list += ["downstroke"]

        # Trough details
        isi_types = trough_details[0]
        trough_detail_indexes = dict(
            zip(["fast_trough", "adp", "slow_trough"], trough_details[1:])
        )
        base_clipped_list += ["fast_trough", "adp", "slow_trough"]

        # Redundant, but ensures that DataFrame has right number of rows
        # Any better way to do it?
        spikes_df = DataFrame(data=thresholds, columns=["threshold_index"])
        spikes_df["clipped"] = clipped

        for k, all_vals in six.iteritems(vit_data_indexes):
            valid_ind = ~np.isnan(all_vals)
            vals = all_vals[valid_ind].astype(int)
            spikes_df[k + "_index"] = np.nan
            spikes_df[k + "_t"] = np.nan
            spikes_df[k + "_v"] = np.nan

            if len(vals) > 0:
                spikes_df.loc[valid_ind, k + "_index"] = vals
                spikes_df.loc[valid_ind, k + "_t"] = t[vals]
                spikes_df.loc[valid_ind, k + "_v"] = v[vals]

            if self.i is not None:
                spikes_df[k + "_i"] = np.nan
                if len(vals) > 0:
                    spikes_df.loc[valid_ind, k + "_i"] = self.i[vals]

            if k in base_clipped_list:
                self._affected_by_clipping += [
                    k + "_index",
                    k + "_t",
                    k + "_v",
                    k + "_i",
                ]

        for k, all_vals in six.iteritems(dvdt_data_indexes):
            valid_ind = ~np.isnan(all_vals)
            vals = all_vals[valid_ind].astype(int)
            spikes_df[k + "_index"] = np.nan
            spikes_df[k] = np.nan
            if len(vals) > 0:
                spikes_df.loc[valid_ind, k + "_index"] = vals
                spikes_df.loc[valid_ind, k + "_t"] = t[vals]
                spikes_df.loc[valid_ind, k + "_v"] = v[vals]
                spikes_df.loc[valid_ind, k] = dvdt[vals]

                if k in base_clipped_list:
                    self._affected_by_clipping += [
                        k + "_index",
                        k + "_t",
                        k + "_v",
                        k,
                    ]

        spikes_df["isi_type"] = isi_types
        self._affected_by_clipping += ["isi_type"]

        for k, all_vals in six.iteritems(trough_detail_indexes):
            valid_ind = ~np.isnan(all_vals)
            vals = all_vals[valid_ind].astype(int)
            spikes_df[k + "_index"] = np.nan
            spikes_df[k + "_t"] = np.nan
            spikes_df[k + "_v"] = np.nan
            if len(vals) > 0:
                spikes_df.loc[valid_ind, k + "_index"] = vals
                spikes_df.loc[valid_ind, k + "_t"] = t[vals]
                spikes_df.loc[valid_ind, k + "_v"] = v[vals]

            if self.i is not None:
                spikes_df[k + "_i"] = np.nan
                if len(vals) > 0:
                    spikes_df.loc[valid_ind, k + "_i"] = self.i[vals]

            if k in base_clipped_list:
                self._affected_by_clipping += [
                    k + "_index",
                    k + "_t",
                    k + "_v",
                    k + "_i",
                ]

        spikes_df["width"] = widths
        self._affected_by_clipping += ["width"]

        spikes_df["upstroke_downstroke_ratio"] = (
            spikes_df["upstroke"] / -spikes_df["downstroke"]
        )
        self._affected_by_clipping += ["upstroke_downstroke_ratio"]

        self._spikes_df = spikes_df

    def _process_spike_related_features(self):
        t = self.t

        if len(self._spikes_df) == 0:
            self._sweep_features["avg_rate"] = 0
            return

        # Start recently added
        peak_heights = None
        if not self._spikes_df.empty:
            peak_heights = (
                self._spikes_df["peak_v"].values - self._spikes_df["threshold_v"].values
            )
        # End recently added

        thresholds = self._spikes_df["threshold_index"].values.astype(int)
        isis = ft.get_isis(t, thresholds)
        with warnings.catch_warnings():
            # ignore mean of empty slice warnings here
            warnings.filterwarnings("ignore", category=RuntimeWarning, module="numpy")

            sweep_level_features = {
                "adapt": ft.adaptation_index(isis),
                "latency": ft.latency(t, thresholds, self.start),
                "isi_cv": (isis.std() / isis.mean()) if len(isis) >= 1 else np.nan,
                "mean_isi": isis.mean() if len(isis) > 0 else np.nan,
                "median_isi": np.median(isis),
                "first_isi": isis[0] if len(isis) >= 1 else np.nan,
                # We want at least 3 peaks (i.e. 2 isis) to calculate the adaptation index (given in percentage)
                "isi_adapt": (isis[1] / isis[0]) if len(isis) >= 2 else np.nan,
                # Start recently added
                # "AP_amp_adapt": self._spikes_df['peak_height'][1]/self._spikes_df['peak_height'][0] if self._spikes_df.shape[1] >= 2 else np.nan,
                "AP_amp_adapt": (peak_heights[1] / peak_heights[0])
                if peak_heights.size >= 2
                else np.nan,
                # "AP_amp_change": ft.ap_amp_change(self._spikes_df['peak_height'].values) if self._spikes_df.shape.shape[1] >= 2 else np.nan,
                "AP_amp_adapt_average": ft.ap_amp_adaptation(peak_heights)
                if peak_heights.size >= 2
                else np.nan,
                # End recently added
                "AP_fano_factor": ((peak_heights.std() ** 2) / peak_heights.mean())
                if peak_heights.size >= 2
                else np.nan,
                "AP_cv": ((peak_heights.std()) / peak_heights.mean())
                if peak_heights.size >= 2
                else np.nan,
                "isi_adapt_average": ft.isi_adaptation(isis)
                if len(isis) >= 2
                else np.nan,
                # "norm_sq_isis": ft.norm_sq_diff(isis) if len(isis) >= 2 else np.nan,
                # You could in principle make the Fano factor and cv 0 for n = 1 ISI, but we choose to make them Nan, i.e. they
                # are not that informative here
                "fano_factor": ((isis.std() ** 2) / isis.mean())
                if len(isis) > 1
                else np.nan,
                "cv": (isis.std() / isis.mean()) if len(isis) > 1 else np.nan,
                "avg_rate": ft.average_rate(t, thresholds, self.start, self.end),
            }

        for k, v in six.iteritems(sweep_level_features):
            self._sweep_features[k] = v

    def _process_pauses(self, cost_weight=1.0):
        # Pauses are unusually long ISIs with a "detour reset" among delay resets
        thresholds = self._spikes_df["threshold_index"].values.astype(int)
        isis = ft.get_isis(self.t, thresholds)
        isi_types = self._spikes_df["isi_type"][:-1].values

        return ft.detect_pauses(isis, isi_types, cost_weight)

    def pause_metrics(self):
        """Estimate average number of pauses and average fraction of time spent in a pause

        Attempts to detect pauses with a variety of conditions and averages results together.

        Pauses that are consistently detected contribute more to estimates.

        Returns
        -------
        avg_n_pauses : average number of pauses detected across conditions
        avg_pause_frac : average fraction of interval (between start and end) spent in a pause
        max_reliability : max fraction of times most reliable pause was detected given weights tested
        n_max_rel_pauses : number of pauses detected with `max_reliability`
        """

        thresholds = self._spikes_df["threshold_index"].values.astype(int)
        isis = ft.get_isis(self.t, thresholds)

        weight = 1.0
        pause_list = self._process_pauses(weight)

        if len(pause_list) == 0:
            return 0, 0.0

        n_pauses = len(pause_list)
        pause_frac = isis[pause_list].sum()
        pause_frac /= self.end - self.start

        return n_pauses, pause_frac

    def _process_bursts(self, tol=0.5, pause_cost=1.0):
        thresholds = self._spikes_df["threshold_index"].values.astype(int)
        isis = ft.get_isis(self.t, thresholds)

        isi_types = self._spikes_df["isi_type"][:-1].values

        fast_tr_v = self._spikes_df["fast_trough_v"].values
        fast_tr_t = self._spikes_df["fast_trough_t"].values
        slow_tr_v = self._spikes_df["slow_trough_v"].values
        slow_tr_t = self._spikes_df["slow_trough_t"].values
        thr_v = self._spikes_df["threshold_v"].values

        bursts = ft.detect_bursts(
            isis,
            isi_types,
            fast_tr_v,
            fast_tr_t,
            slow_tr_v,
            slow_tr_t,
            thr_v,
            tol,
            pause_cost,
        )

        return np.array(bursts)

    def burst_metrics(self):
        """Find bursts and return max "burstiness" index (normalized max rate in burst vs out).

        Returns
        -------
        max_burstiness_index : max "burstiness" index across detected bursts
        num_bursts : number of bursts detected
        """

        burst_info = self._process_bursts()

        if burst_info.shape[0] > 0:
            return burst_info[:, 0].max(), burst_info.shape[0]
        else:
            return 0.0, 0

    def delay_metrics(self):
        """Calculates ratio of latency to dominant time constant of rise before spike

        Returns
        -------
        delay_ratio : ratio of latency to tau (higher means more delay)
        tau : dominant time constant of rise before spike
        """

        if len(self._spikes_df) == 0:
            logging.info("No spikes available for delay calculation")
            return 0.0, 0.0
        start = self.start
        spike_time = self._spikes_df["threshold_t"].values[0]

        tau = ft.fit_prespike_time_constant(self.v, self.t, start, spike_time)
        latency = spike_time - start

        delay_ratio = latency / tau
        return delay_ratio, tau

    def _get_baseline_voltage(self):
        v = self.v
        t = self.t
        filter_frequency = 1.0  # in kHz

        # Look at baseline interval before start if start is defined
        if self.start is not None:
            return ft.average_voltage(
                v, t, self.start - self.baseline_interval, self.start
            )

        # Otherwise try to find an interval where things are pretty flat
        dv = ft.calculate_dvdt(v, t, filter_frequency)
        non_flat_points = np.flatnonzero(np.abs(dv >= self.baseline_detect_thresh))
        flat_intervals = t[non_flat_points[1:]] - t[non_flat_points[:-1]]
        long_flat_intervals = np.flatnonzero(flat_intervals >= self.baseline_interval)
        if long_flat_intervals.size > 0:
            interval_index = long_flat_intervals[0] + 1
            baseline_end_time = t[non_flat_points[interval_index]]
            return ft.average_voltage(
                v, t, baseline_end_time - self.baseline_interval, baseline_end_time
            )
        else:
            logging.info(
                "Could not find sufficiently flat interval for automatic baseline voltage",
                RuntimeWarning,
            )
            return np.nan

    def voltage_deflection(self, deflect_type=None):
        """Measure deflection (min or max, between start and end if specified).

        Parameters
        ----------
        deflect_type : measure minimal ('min') or maximal ('max') voltage deflection
            If not specified, it will check to see if the current (i) is positive or negative
            between start and end, then choose 'max' or 'min', respectively
            If the current is not defined, it will default to 'min'.

        Returns
        -------
        deflect_v : peak
        deflect_index : index of peak deflection
        """

        deflect_dispatch = {
            "min": np.argmin,
            "max": np.argmax,
        }

        start = self.start
        if not start:
            start = 0
        start_index = ft.find_time_index(self.t, start)

        end = (
            self.end - 0.1
        )  # Let us add -0.1 because we don't expect to find a trough that close to the end of current stimulation
        # This actually helps us ignore cases where the voltage acts funny (i.e. drops mistakenly taken as trough)
        # right at current stimulation end.
        if not end:
            end = self.t[-1]
        end_index = ft.find_time_index(self.t, end)

        if deflect_type is None:
            if self.i is not None:
                halfway_index = ft.find_time_index(self.t, (end - start) / 2.0 + start)
                if self.i[halfway_index] >= 0:
                    deflect_type = "max"
                else:
                    deflect_type = "min"
            else:
                deflect_type = "min"

        deflect_func = deflect_dispatch[deflect_type]

        v_window = self.v[start_index:end_index]
        deflect_index = deflect_func(v_window) + start_index

        return self.v[deflect_index], deflect_index

    def stimulus_amplitude(self):
        """ """
        if self.stimulus_amplitude_calculator is not None:
            return self.stimulus_amplitude_calculator(self)
        else:
            return np.nan

    def estimate_time_constant(self):
        """Calculate the membrane time constant by fitting the voltage response with a
        single exponential.

        Returns
        -------
        tau : membrane time constant in seconds
        """

        # Assumes this is being done on a hyperpolarizing step
        v_peak, peak_index = self.voltage_deflection("min")
        v_baseline = self.sweep_feature("v_baseline")

        if self.start:
            start_index = ft.find_time_index(self.t, self.start)
        else:
            start_index = 0

        frac = 0.1
        search_result = np.flatnonzero(
            self.v[start_index:] <= frac * (v_peak - v_baseline) + v_baseline
        )
        if not search_result.size:
            return np.nan
            # raise ft.FeatureError("could not find interval for time constant estimate")

        fit_start = self.t[search_result[0] + start_index]
        fit_end = self.t[peak_index]

        # There was one cell with a noisy (?) peak downwards (to -250 mV) unfortunately. That's why we have the if-statement here.
        # You can delete this if-statement if you have normal traces.
        # If this all still didn't work as expected, then hopefully there are more hyperpolarisation traces for which tau can be estimated
        #         if (self.v[peak_index] < -200) :
        #             print("A DOWNWARD PEAK WAS OBSERVED GOING TO LESS THAN 200 MV!!!")
        #             # Look for another local minimum closer to stimulus onset
        #             # We look for a couple of milliseconds after stimulus onset to 50 ms before the downward peak
        #             end_index = (start_index + 50) + np.argmin(self.v[start_index + 50 : peak_index - 1250])
        #             fit_end = self.t[end_index]
        #             fit_start = self.t[start_index + 50]

        a, inv_tau, y0 = ft.fit_membrane_time_constant(
            self.v, self.t, fit_start, fit_end
        )

        return 1.0 / inv_tau

    def estimate_time_constant_at_end(self):
        """Calculate the membrane time constant by fitting the voltage response with a single expontial at the end of a hyperpolarising
        stimulus.

        Returns
        -------
        tau : membrane time constant in seconds
        """

        # Assumes this is being done on a hyperpolarizing step
        v_peak, peak_index = self.voltage_deflection("min")
        v_baseline = self.sweep_feature("v_baseline")
        if self.end:
            start_index = ft.find_time_index(self.t, self.end)
        else:
            start_index = ft.find_time_index(self.t, 0.7)

        frac = 0.1
        search_result = np.flatnonzero(
            self.v[start_index:] >= frac * (v_baseline - v_peak) + v_peak
        )
        if not search_result.size:
            raise ft.FeatureError("Could not find interval for time constant estimate")
        fit_start = self.t[search_result[0] + start_index]
        fit_end = self.t[-1]
        b, inv_tau, A = ft.fit_membrane_time_constant_at_end(
            self.v, self.t, fit_start, fit_end
        )

        return 1.0 / inv_tau

    def estimate_sag(self, peak_width=0.005):
        """Calculate the sag in a hyperpolarizing voltage response.

        Parameters
        ----------
        peak_width : window width to get more robust peak estimate in sec (default 0.005)

        Returns
        -------
        sag : fraction that membrane potential relaxes back to baseline
        sag_ratio: ratio of steady state voltage decrease to the largest voltage decrease
        """

        t = self.t
        v = self.v

        start = self.start
        if not start:
            start = 0

        end = (
            self.end
        )  # To calculate the steady state, not the peak deflection (see code below)
        if not end:
            end = self.t[-1]

        v_peak, peak_index = self.voltage_deflection("min")

        # There was one cell with a noisy (?) peak downwards (to -250 mV) unfortunately. That's why we have the if-statement here.
        # You can delete this if-statement if you have normal traces.
        # If this all still didn't work as expected, then hopefully there are more hyperpolarisation traces for which tau can be estimated
        if self.v[peak_index] < -200:
            print("A DOWNWARD PEAK WAS OBSERVED GOING TO LESS THAN 200 MV!!!")
            # Look for another local minimum closer to stimulus onset
            # A spike should only last about a couple of milliseconds, so let's look a bit before the 'spike'
            peak_index = peak_index - (
                ft.find_time_index(t, 0.12) - ft.find_time_index(t, 0.1)
            )

        # print(t[peak_index])
        v_peak_avg = ft.average_voltage(
            v,
            t,
            start=t[peak_index] - peak_width / 2.0,
            end=t[peak_index] + peak_width / 2.0,
        )
        v_baseline = self.sweep_feature("v_baseline")
        v_steady = ft.average_voltage(v, t, start=end - self.baseline_interval, end=end)
        # print('v_stead: ', v_steady)
        # print('v_baseline: ', v_baseline)
        # print('v_peak_avg: ', v_peak_avg)
        # print('denominater=v_stead-v_baseline: ', v_steady-v_baseline)
        # print('numerator=v_peak_avg-v_baseline: ', v_peak_avg-v_baseline)
        sag = (v_peak_avg - v_steady) / (v_peak_avg - v_baseline)
        sag_ratio = (v_peak_avg - v_baseline) / (v_steady - v_baseline)
        # print(sag_ratio)
        return sag, sag_ratio

    def spikes(self):
        """Get all features for each spike as a list of records."""
        return self._spikes_df.to_dict("records")

    def spike_feature(self, key, include_clipped=False, force_exclude_clipped=False):
        """Get specified feature for every spike.

        Parameters
        ----------
        key : feature name
        include_clipped: return values for every identified spike, even when clipping means they will be incorrect/undefined

        Returns
        -------
        spike_feature_values : ndarray of features for each spike
        """

        if not hasattr(self, "_spikes_df"):
            raise AttributeError(
                "EphysSweepFeatureExtractor instance attribute with spike information does not exist yet - have spikes been processed?"
            )

        if len(self._spikes_df) == 0:
            return np.array([])

        if key not in self._spikes_df.columns:
            raise KeyError("requested feature '{:s}' not available".format(key))

        values = self._spikes_df[key].values

        if include_clipped and force_exclude_clipped:
            raise ValueError(
                "include_clipped and force_exclude_clipped cannot both be true"
            )

        if not include_clipped and self.is_spike_feature_affected_by_clipping(key):
            values = values[~self._spikes_df["clipped"].values]
        elif force_exclude_clipped:
            values = values[~self._spikes_df["clipped"].values]

        return values

    def is_spike_feature_affected_by_clipping(self, key):
        return key in self._affected_by_clipping

    def spike_feature_keys(self):
        """Get list of every available spike feature."""
        return self._spikes_df.columns.values.tolist()

    def sweep_feature(self, key, allow_missing=False):
        """Get sweep-level feature (`key`).

        Parameters
        ----------
        key : name of sweep-level feature
        allow_missing : return np.nan if key is missing for sweep (default False)

        Returns
        -------
        sweep_feature : sweep-level feature value
        """

        on_request_dispatch = {
            "v_baseline": self._get_baseline_voltage,
            "tau": self.estimate_time_constant,
            "sag": self.estimate_sag,
            "peak_deflect": self.voltage_deflection,
            "stim_amp": self.stimulus_amplitude,
        }

        if (
            allow_missing
            and key not in self._sweep_features
            and key not in on_request_dispatch
        ):
            return np.nan
        elif key not in self._sweep_features and key not in on_request_dispatch:
            raise KeyError("requested feature '{:s}' not available".format(key))

        if key not in self._sweep_features and key in on_request_dispatch:
            fn = on_request_dispatch[key]
            if fn is not None:
                self._sweep_features[key] = fn()
            else:
                raise KeyError("requested feature '{:s}' not defined".format(key))

        return self._sweep_features[key]

    def process_new_spike_feature(
        self, feature_name, feature_func, affected_by_clipping=False
    ):
        """Add new spike-level feature calculation function

        The function should take this sweep extractor as its argument. Its results
        can be accessed by calling the method spike_feature(<feature_name>).
        """

        if feature_name in self._spikes_df.columns:
            raise KeyError("Feature {:s} already exists for sweep".format(feature_name))

        self._spikes_df[feature_name] = feature_func(self)

        if affected_by_clipping:
            self._affected_by_clipping.append(feature_name)

    def process_new_sweep_feature(self, feature_name, feature_func):
        """Add new sweep-level feature calculation function

        The function should take this sweep extractor as its argument. Its results
        can be accessed by calling the method sweep_feature(<feature_name>).
        """

        if feature_name in self._sweep_features:
            raise KeyError("Feature {:s} already exists for sweep".format(feature_name))

        self._sweep_features[feature_name] = feature_func(self)

    def set_stimulus_amplitude_calculator(self, function):
        self.stimulus_amplitude_calculator = function

    def sweep_feature_keys(self):
        """Get list of every available sweep-level feature."""
        return self._sweep_features.keys()

    def as_dict(self):
        """Create dict of features and spikes."""
        output_dict = self._sweep_features.copy()
        output_dict["spikes"] = self.spikes()
        if self.id is not None:
            output_dict["id"] = self.id
        return output_dict


class EphysSweepSetFeatureExtractor:
    def __init__(
        self,
        t_set=None,
        v_set=None,
        i_set=None,
        start=None,
        end=None,
        filter=10.0,
        dv_cutoff=20.0,
        max_interval=0.005,
        min_height=2.0,
        min_peak=-30.0,
        thresh_frac=0.05,
        baseline_interval=0.1,
        baseline_detect_thresh=0.3,
        id_set=None,
    ):
        """Initialize EphysSweepSetFeatureExtractor object.

        Parameters
        ----------
        t_set : list of ndarray of times in seconds
        v_set : list of ndarray of voltages in mV
        i_set : list of ndarray of currents in pA
        start : start of time window for feature analysis (optional, can be list)
        end : end of time window for feature analysis (optional, can be list)
        filter : cutoff frequency for 4-pole low-pass Bessel filter in kHz (optional, default 10)
        dv_cutoff : minimum dV/dt to qualify as a spike in V/s (optional, default 20)
        max_interval : maximum acceptable time between start of spike and time of peak in sec (optional, default 0.005)
        min_height : minimum acceptable height from threshold to peak in mV (optional, default 2)
        min_peak : minimum acceptable absolute peak level in mV (optional, default -30)
        thresh_frac : fraction of average upstroke for threshold calculation (optional, default 0.05)
        baseline_interval: interval length for baseline voltage calculation (before start if start is defined, default 0.1)
        baseline_detect_thresh : dV/dt threshold for evaluating flatness of baseline region (optional, default 0.3)
        """

        if t_set is not None and v_set is not None:
            self._set_sweeps(
                t_set,
                v_set,
                i_set,
                start,
                end,
                filter,
                dv_cutoff,
                max_interval,
                min_height,
                min_peak,
                thresh_frac,
                baseline_interval,
                baseline_detect_thresh,
                id_set,
            )
        else:
            self._sweeps = None

    @classmethod
    def from_sweeps(cls, sweep_list):
        """Initialize EphysSweepSetFeatureExtractor object with a list of pre-existing
        sweep feature extractor objects.
        """

        obj = cls()
        obj._sweeps = sweep_list
        return obj

    def _set_sweeps(
        self,
        t_set,
        v_set,
        i_set,
        start,
        end,
        filter,
        dv_cutoff,
        max_interval,
        min_height,
        min_peak,
        thresh_frac,
        baseline_interval,
        baseline_detect_thresh,
        id_set,
    ):
        if type(t_set) != list:
            raise ValueError("t_set must be a list")

        if type(v_set) != list:
            raise ValueError("v_set must be a list")

        if i_set is not None and type(i_set) != list:
            raise ValueError("i_set must be a list")

        if len(t_set) != len(v_set):
            raise ValueError("t_set and v_set must have the same number of items")

        if i_set and len(t_set) != len(i_set):
            raise ValueError("t_set and i_set must have the same number of items")

        if id_set is None:
            id_set = range(len(t_set))
        if len(id_set) != len(t_set):
            raise ValueError("t_set and id_set must have the same number of items")

        sweeps = []
        if i_set is None:
            i_set = [None] * len(t_set)

        if type(start) is not list:
            start = [start] * len(t_set)
            end = [end] * len(t_set)

        sweeps = [
            EphysSweepFeatureExtractor(
                t,
                v,
                i,
                start,
                end,
                filter=filter,
                dv_cutoff=dv_cutoff,
                max_interval=max_interval,
                min_height=min_height,
                min_peak=min_peak,
                thresh_frac=thresh_frac,
                baseline_interval=baseline_interval,
                baseline_detect_thresh=baseline_detect_thresh,
                id=sid,
            )
            for t, v, i, start, end, sid in zip(t_set, v_set, i_set, start, end, id_set)
        ]

        self._sweeps = sweeps

    def sweeps(self):
        """Get list of EphysSweepFeatureExtractor objects."""
        return self._sweeps

    def process_spikes(self):
        """Analyze spike features for all sweeps."""
        for sweep in self._sweeps:
            sweep.process_spikes()

    def sweep_features(self, key, allow_missing=False):
        """Get nparray of sweep-level feature (`key`) for all sweeps

        Parameters
        ----------
        key : name of sweep-level feature
        allow_missing : return np.nan if key is missing for sweep (default False)

        Returns
        -------
        sweep_feature : nparray of sweep-level feature values
        """

        return np.array([swp.sweep_feature(key, allow_missing) for swp in self._sweeps])

    def spike_feature_averages(self, key):
        """Get nparray of average spike-level feature (`key`) for all sweeps"""
        return np.array([swp.spike_feature(key).mean() for swp in self._sweeps])


class EphysCellFeatureExtractor:
    # Class constants for specific processing
    SUBTHRESH_MAX_AMP = 0
    SAG_TARGET = -100.0

    def __init__(
        self, ramps_ext, short_squares_ext, long_squares_ext, subthresh_min_amp=-100
    ):
        """Initialize EphysCellFeatureExtractor object from EphysSweepSetExtractors for
        ramp, short square, and long square sweeps.

        Parameters
        ----------
        dataset : NwbDataSet
        ramps_ext : EphysSweepSetFeatureExtractor prepared with ramp sweeps
        short_squares_ext : EphysSweepSetFeatureExtractor prepared with short square sweeps
        long_squares_ext : EphysSweepSetFeatureExtractor prepared with long square sweeps
        """

        self._ramps_ext = ramps_ext
        self._short_squares_ext = short_squares_ext
        self._long_squares_ext = long_squares_ext

        self._subthresh_min_amp = subthresh_min_amp

        self._features = {
            "ramps": {},
            "short_squares": {},
            "long_squares": {},
        }

        self._spiking_long_squares_ext = None
        self._subthreshold_long_squares_ext = None
        self._subthreshold_membrane_property_ext = None

    def process(self, keys=None):
        """Processes features. Can take a specific key (or set of keys) to do a subset of processing."""

        dispatch = {
            "ramps": self._analyze_ramps,
            "short_squares": self._analyze_short_squares,
            "long_squares": self._analyze_long_squares,
            "long_squares_spiking": self._analyze_long_squares_spiking,
        }

        if keys is None:
            keys = dispatch.keys()

        if type(keys) is not list:
            keys = [keys]

        for k in [j for j in keys if j in dispatch.keys()]:
            dispatch[k]()

    def _analyze_ramps(self):
        ext = self._ramps_ext
        ext.process_spikes()

        self._all_ramps_ext = ext

        # pull out the spiking sweeps
        spiking_sweeps = [
            sweep
            for sweep in self._ramps_ext.sweeps()
            if sweep.sweep_feature("avg_rate") > 0
        ]
        ext = EphysSweepSetFeatureExtractor.from_sweeps(spiking_sweeps)
        self._ramps_ext = ext

        self._features["ramps"]["spiking_sweeps"] = ext.sweeps()

    def ramps_features(self, all=False):
        if all:
            return self._all_ramps_ext
        else:
            return self._ramps_ext

    def _analyze_short_squares(self):
        ext = self._short_squares_ext
        ext.process_spikes()

        # Need to count how many had spikes at each amplitude; find most; ties go to lower amplitude
        spiking_sweeps = [
            sweep for sweep in ext.sweeps() if sweep.sweep_feature("avg_rate") > 0
        ]

        if len(spiking_sweeps) == 0:
            raise ft.FeatureError(
                "No spiking short square sweeps, cannot compute cell features."
            )

        most_common = Counter(map(_short_step_stim_amp, spiking_sweeps)).most_common()
        common_amp, common_count = most_common[0]
        for c in most_common[1:]:
            if c[1] < common_count:
                break
            if c[0] < common_amp:
                common_amp = c[0]

        self._features["short_squares"]["stimulus_amplitude"] = common_amp
        ext = EphysSweepSetFeatureExtractor.from_sweeps(
            [
                sweep
                for sweep in spiking_sweeps
                if _short_step_stim_amp(sweep) == common_amp
            ]
        )
        self._short_squares_ext = ext

        self._features["short_squares"]["common_amp_sweeps"] = ext.sweeps()
        for s in self._features["short_squares"]["common_amp_sweeps"]:
            s.set_stimulus_amplitude_calculator(_short_step_stim_amp)

    def short_squares_features(self):
        return self._short_squares_ext

    def _analyze_long_squares(self):
        self._analyze_long_squares_spiking()
        self._analyze_long_squares_subthreshold()

    def _analyze_long_squares_spiking(self, force_reprocess=False):
        if not force_reprocess and self._spiking_long_squares_ext:
            return

        ext = self._long_squares_ext
        ext.process_spikes()
        self._features["long_squares"]["sweeps"] = ext.sweeps()
        for s in self._features["long_squares"]["sweeps"]:
            s.set_stimulus_amplitude_calculator(_step_stim_amp)

        spiking_indexes = np.flatnonzero(ext.sweep_features("avg_rate"))

        if len(spiking_indexes) == 0:
            raise ft.FeatureError(
                "No spiking long square sweeps, cannot compute cell features."
            )

        amps = ext.sweep_features("stim_amp")  # self.long_squares_stim_amps()
        min_index = np.argmin(amps[spiking_indexes])
        rheobase_index = spiking_indexes[min_index]
        rheobase_i = _step_stim_amp(ext.sweeps()[rheobase_index])

        self._features["long_squares"]["rheobase_extractor_index"] = rheobase_index
        self._features["long_squares"]["rheobase_i"] = rheobase_i
        self._features["long_squares"]["rheobase_sweep"] = ext.sweeps()[rheobase_index]
        spiking_sweeps = [
            sweep for sweep in ext.sweeps() if sweep.sweep_feature("avg_rate") > 0
        ]
        self._spiking_long_squares_ext = EphysSweepSetFeatureExtractor.from_sweeps(
            spiking_sweeps
        )
        self._features["long_squares"][
            "spiking_sweeps"
        ] = self._spiking_long_squares_ext.sweeps()

        self._features["long_squares"]["fi_fit_slope"] = fit_fi_slope(
            self._spiking_long_squares_ext
        )

    def _analyze_long_squares_subthreshold(self):
        ext = self._long_squares_ext
        subthresh_sweeps = [
            sweep for sweep in ext.sweeps() if sweep.sweep_feature("avg_rate") == 0
        ]
        subthresh_ext = EphysSweepSetFeatureExtractor.from_sweeps(subthresh_sweeps)
        self._subthreshold_long_squares_ext = subthresh_ext

        if len(subthresh_ext.sweeps()) == 0:
            raise ft.FeatureError(
                "No subthreshold long square sweeps, cannot evaluate cell features."
            )

        peaks = subthresh_ext.sweep_features("peak_deflect")
        sags = subthresh_ext.sweep_features("sag")
        sag_eval_levels = np.array(
            [sweep.voltage_deflection()[0] for sweep in subthresh_ext.sweeps()]
        )
        target_level = self.SAG_TARGET
        closest_index = np.argmin(np.abs(sag_eval_levels - target_level))
        self._features["long_squares"]["sag"] = sags[closest_index]
        self._features["long_squares"]["vm_for_sag"] = sag_eval_levels[closest_index]
        self._features["long_squares"]["subthreshold_sweeps"] = subthresh_ext.sweeps()
        for s in self._features["long_squares"]["subthreshold_sweeps"]:
            s.set_stimulus_amplitude_calculator(_step_stim_amp)

        logging.debug("subthresh_sweeps: %d", len(subthresh_sweeps))
        calc_subthresh_sweeps = [
            sweep
            for sweep in subthresh_sweeps
            if sweep.sweep_feature("stim_amp") < self.SUBTHRESH_MAX_AMP
            and sweep.sweep_feature("stim_amp") > self._subthresh_min_amp
        ]

        logging.debug("calc_subthresh_sweeps: %d", len(calc_subthresh_sweeps))
        calc_subthresh_ext = EphysSweepSetFeatureExtractor.from_sweeps(
            calc_subthresh_sweeps
        )
        self._subthreshold_membrane_property_ext = calc_subthresh_ext
        self._features["long_squares"][
            "subthreshold_membrane_property_sweeps"
        ] = calc_subthresh_ext.sweeps()
        self._features["long_squares"]["input_resistance"] = input_resistance(
            calc_subthresh_ext
        )
        self._features["long_squares"]["tau"] = membrane_time_constant(
            calc_subthresh_ext
        )
        self._features["long_squares"]["v_baseline"] = np.nanmean(
            ext.sweep_features("v_baseline")
        )

    def long_squares_features(self, option=None):
        option_table = {
            "spiking": self._spiking_long_squares_ext,
            "subthreshold": self._subthreshold_long_squares_ext,
            "subthreshold_membrane_property": self._subthreshold_membrane_property_ext,
        }
        if option:
            return option_table[option]

        return self._long_squares_ext

    def long_squares_stim_amps(self, option=None):
        option_table = {
            "spiking": self._spiking_long_squares_ext,
            "subthreshold": self._subthreshold_long_squares_ext,
            "subthreshold_membrane_property": self._subthreshold_membrane_property_ext,
        }
        if option:
            ext = option_table[option]
        else:
            ext = self._long_squares_ext

        return np.array(map(_step_stim_amp, ext.sweeps()))

    def cell_features(self):
        return self._features

    def as_dict(self):
        """Create dict of cell features."""

        # get shallow copies of the sub-type dictionaries
        out = {
            "long_squares": self._features["long_squares"].copy(),
            "short_squares": self._features["short_squares"].copy(),
            "ramps": self._features["ramps"].copy(),
        }

        # convert feature extractor lists to sweep dictionarsweep extract lists
        ls_sweeps = [s.as_dict() for s in out["long_squares"]["sweeps"]]
        ls_spike_sweeps = [s.as_dict() for s in out["long_squares"]["spiking_sweeps"]]
        rheo_sweep = out["long_squares"]["rheobase_sweep"].as_dict()
        ls_sub_sweeps = [
            s.as_dict() for s in out["long_squares"]["subthreshold_sweeps"]
        ]
        ls_sub_mem_sweeps = [
            s.as_dict()
            for s in out["long_squares"]["subthreshold_membrane_property_sweeps"]
        ]
        ss_sweeps = [s.as_dict() for s in out["short_squares"]["common_amp_sweeps"]]
        ramp_sweeps = [s.as_dict() for s in out["ramps"]["spiking_sweeps"]]

        out["long_squares"]["sweeps"] = ls_sweeps
        out["long_squares"]["spiking_sweeps"] = ls_spike_sweeps
        out["long_squares"]["subthreshold_sweeps"] = ls_sub_sweeps
        out["long_squares"]["subthreshold_membrane_property_sweeps"] = ls_sub_mem_sweeps
        out["long_squares"]["rheobase_sweep"] = rheo_sweep
        out["short_squares"]["common_amp_sweeps"] = ss_sweeps
        out["ramps"]["spiking_sweeps"] = ramp_sweeps

        return out


def input_resistance(ext):
    """Estimate input resistance in MOhms, assuming all sweeps in passed extractor
    are hyperpolarizing responses."""

    sweeps = ext.sweeps()
    if not sweeps:
        raise ft.FeatureError("no sweeps available for input resistance calculation")

    v_vals = []
    i_vals = []
    for sweep in sweeps:
        if sweep.i is None:
            raise ft.FeatureError(
                "cannot calculate input resistance: i not defined for a sweep"
            )

        v_peak, min_index = sweep.voltage_deflection("min")
        v_vals.append(v_peak)
        i_vals.append(sweep.i[min_index])

    v = np.array(v_vals)
    i = np.array(i_vals)

    if len(v) == 1:
        # If there's just one sweep, we'll have to use its own baseline to estimate
        # the input resistance
        v = np.append(v, sweeps[0].sweep_feature("v_baseline"))
        i = np.append(i, 0.0)

    A = np.vstack([i, np.ones_like(i)]).T
    m, c = np.linalg.lstsq(A, v)[0]

    return m * 1e3


def membrane_time_constant(ext):
    """Average the membrane time constant values estimated from each sweep in passed extractor."""

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=RuntimeWarning, module="numpy")
        avg_tau = np.nanmean(ext.sweep_features("tau"))
    return avg_tau


def fit_fi_slope(ext):
    """Fit the rate and stimulus amplitude to a line and return the slope of the fit."""
    if len(ext.sweeps()) < 2:
        raise ft.FeatureError(
            "Cannot fit f-I curve slope with less than two suprathreshold sweeps"
        )

    x = np.array(map(_step_stim_amp, ext.sweeps()))
    y = ext.sweep_features("avg_rate")

    A = np.vstack([x, np.ones_like(x)]).T
    m, c = np.linalg.lstsq(A, y)[0]

    return m


def reset_long_squares_start(when):
    global LONG_SQUARES_START, LONG_SQUARES_END
    delta = LONG_SQUARES_END - LONG_SQUARES_START
    LONG_SQUARES_START = when
    LONG_SQUARES_END = when + delta


def cell_extractor_for_nwb(
    dataset, ramps, short_squares, long_squares, subthresh_min_amp=-100
):
    """Initialize EphysCellFeatureExtractor object from NWB data set

    Parameters
    ----------
    dataset : NwbDataSet
    ramps : list of sweep numbers of ramp sweeps
    short_squares : list of sweep numbers of short square sweeps
    long_squares : list of sweep numbers of long square sweeps
    """

    if len(short_squares) == 0:
        raise ft.FeatureError("no short square sweep numbers provided")
    if len(ramps) == 0:
        raise ft.FeatureError("no ramp sweep numbers provided")
    if len(long_squares) == 0:
        raise ft.FeatureError("no long_square sweep numbers provided")

    ramps_ext = extractor_for_nwb_sweeps(dataset, ramps, fixed_start=RAMPS_START)

    temp_short_sq_ext = extractor_for_nwb_sweeps(dataset, short_squares)
    t_set = [s.t for s in temp_short_sq_ext.sweeps()]
    v_set = [s.v for s in temp_short_sq_ext.sweeps()]
    cutoff, thresh_frac = ft.estimate_adjusted_detection_parameters(
        v_set, t_set, SHORT_SQUARES_WINDOW_START, SHORT_SQUARES_WINDOW_END
    )

    thresh_frac = max(thresh_frac, 0.1)

    short_squares_ext = extractor_for_nwb_sweeps(
        dataset, short_squares, dv_cutoff=cutoff, thresh_frac=thresh_frac
    )
    long_squares_ext = extractor_for_nwb_sweeps(
        dataset,
        long_squares,
        fixed_start=LONG_SQUARES_START,
        fixed_end=LONG_SQUARES_END,
    )

    return EphysCellFeatureExtractor(
        ramps_ext, short_squares_ext, long_squares_ext, subthresh_min_amp
    )


def extractor_for_nwb_sweeps(
    dataset,
    sweep_numbers,
    fixed_start=None,
    fixed_end=None,
    dv_cutoff=20.0,
    thresh_frac=0.05,
):
    v_set = []
    t_set = []
    i_set = []

    start = []
    end = []

    for sweep_number in sweep_numbers:
        data = dataset.get_sweep(sweep_number)
        v = data["response"] * 1e3  # mV
        i = data["stimulus"] * 1e12  # pA
        hz = data["sampling_rate"]
        dt = 1.0 / hz
        t = np.arange(0, len(v)) * dt  # sec

        s, e = dt * np.array(data["index_range"])
        v_set.append(v)
        i_set.append(i)
        t_set.append(t)
        start.append(s)
        end.append(e)

    if fixed_start and not fixed_end:
        start = [fixed_start] * len(end)
    elif fixed_start and fixed_end:
        start = fixed_start
        end = fixed_end

    return EphysSweepSetFeatureExtractor(
        t_set,
        v_set,
        i_set,
        start=start,
        end=end,
        dv_cutoff=dv_cutoff,
        thresh_frac=thresh_frac,
        id_set=sweep_numbers,
    )


def _step_stim_amp(sweep):
    t_index = ft.find_time_index(sweep.t, sweep.start)
    return sweep.i[t_index + 1]


def _short_step_stim_amp(sweep):
    t_index = ft.find_time_index(sweep.t, sweep.start)
    return sweep.i[t_index + 1 :].max()
