import numpy as np
from scipy import stats as spstats
import torch

def syn_current(duration=120, dt=0.1, t_on=10, curr_level=5e-4, seed=None):
    """Build a step-current stimulus on a regular time grid.

    The current is zero everywhere except on the sample range
    ``[round(t_on / dt), round(t_off / dt))``, where it equals
    ``curr_level / A_soma``.

    Parameters
    ----------
    duration : float
        End of the time grid; the grid is ``np.arange(0, duration + dt, dt)``.
    dt : float
        Spacing of the time grid.
    t_on : float
        Stimulus onset time. The offset is placed symmetrically at
        ``duration - t_on``.
    curr_level : float
        Injected current level; divided by the soma area to obtain a density.
    seed : int, optional
        Not used by this function; kept for interface compatibility.

    Returns
    -------
    tuple
        ``(I_inj, t_on, t_off, dt, t, A_soma)`` where ``I_inj`` is the current
        density (muA/cm2) sampled on ``t`` and ``A_soma`` is the soma area
        (cm2).
    """
    t_off = duration - t_on
    t = np.arange(0, duration + dt, dt)

    # external current
    A_soma = np.pi * ((70.0 * 1e-4) ** 2)  # cm2

    # Force a float buffer: with integer ``duration`` and ``dt`` the grid ``t``
    # is an integer array, and inheriting its dtype would truncate the
    # current density to whole numbers.
    I_inj = np.zeros_like(t, dtype=float)
    onset_idx = int(np.round(t_on / dt))
    offset_idx = int(np.round(t_off / dt))
    I_inj[onset_idx:offset_idx] = curr_level / A_soma  # muA/cm2

    return I_inj, t_on, t_off, dt, t, A_soma


def HHsimulator(V0, params, dt, t, I_inj, seed=None):
    """Simulate a noisy Hodgkin-Huxley-type voltage trace.

    The membrane carries a sodium, a potassium, a leak and a slow
    non-inactivating K+ conductance. At every step the voltage and each gating
    variable relax exponentially towards their instantaneous steady state
    over one time step. Gaussian noise drawn from a ``RandomState`` is added
    to the input current at every step.

    Parameters
    ----------
    V0 : float
        Voltage at first time step
    params : array-like
        Parameter vector, either flat or as a single row of shape
        ``(1, dim_param)``. The first two entries are ``gbar_Na`` and
        ``gbar_K`` (mS/cm2).
    dt : float
        Timestep
    t : np.ndarray
        Numpy array with the time steps; only its shape is used.
    I_inj : array
        Numpy array with the input current, indexed by time step
    seed : int, optional
        Seed for the noise generator. ``None`` gives a non-reproducible run.

    Returns
    -------
    np.ndarray
        Voltage trace as a column of shape ``(n_steps, 1)``.
    """

    # Flatten so both a flat vector and a (1, dim_param) row are accepted, and
    # cast to float up front (the cast result must be kept to have any effect).
    theta = np.asarray(params, dtype=float).reshape(-1)
    gbar_Na = theta[0]  # mS/cm2
    gbar_K = theta[1]  # mS/cm2

    # fixed parameters
    g_leak = 0.1  # mS/cm2
    gbar_M = 0.07  # mS/cm2
    tau_max = 6e2  # ms
    Vt = -60.0  # mV
    nois_fact = 0.1  # uA/cm2
    E_leak = -70.0  # mV
    C = 1.0  # uF/cm2
    E_Na = 53  # mV
    E_K = -107  # mV

    tstep = float(dt)

    # RandomState(None) seeds itself from the OS, so one call covers both cases
    rng = np.random.RandomState(seed)

    ####################################
    # kinetics
    def efun(z):
        # z / (exp(z) - 1), with a first-order expansion near z = 0 where the
        # closed form would evaluate 0 / 0
        if np.abs(z) < 1e-4:
            return 1 - z / 2
        else:
            return z / (np.exp(z) - 1)

    def alpha_m(x):
        v1 = x - Vt - 13.0
        return 0.32 * efun(-0.25 * v1) / 0.25

    def beta_m(x):
        v1 = x - Vt - 40
        return 0.28 * efun(0.2 * v1) / 0.2

    def alpha_h(x):
        v1 = x - Vt - 17.0
        return 0.128 * np.exp(-v1 / 18.0)

    def beta_h(x):
        v1 = x - Vt - 40.0
        return 4.0 / (1 + np.exp(-0.2 * v1))

    def alpha_n(x):
        v1 = x - Vt - 15.0
        return 0.032 * efun(-0.2 * v1) / 0.2

    def beta_n(x):
        v1 = x - Vt - 10.0
        return 0.5 * np.exp(-v1 / 40)

    # steady-states and time constants
    def gate(alpha, beta, x):
        # Evaluate each rate once and return (steady state, time constant)
        a = alpha(x)
        b = beta(x)
        return a / (a + b), 1 / (a + b)

    # slow non-inactivating K+
    def p_inf(x):
        v1 = x + 35.0
        return 1.0 / (1.0 + np.exp(-0.1 * v1))

    def tau_p(x):
        v1 = x + 35.0
        return tau_max / (3.3 * np.exp(0.05 * v1) + np.exp(-0.05 * v1))

    ####################################
    # simulation from initial point
    # Float buffers regardless of the dtype of ``t`` (an integer time grid
    # would otherwise truncate every state variable).
    V = np.zeros_like(t, dtype=float)  # voltage
    n = np.zeros_like(V)
    m = np.zeros_like(V)
    h = np.zeros_like(V)
    p = np.zeros_like(V)

    # all gates start at their steady state for the initial voltage
    V[0] = float(V0)
    n[0] = gate(alpha_n, beta_n, V[0])[0]
    m[0] = gate(alpha_m, beta_m, V[0])[0]
    h[0] = gate(alpha_h, beta_h, V[0])[0]
    p[0] = p_inf(V[0])

    for i in range(1, t.shape[0]):
        # instantaneous conductances from the previous step's gate values
        g_Na = (m[i - 1] ** 3) * gbar_Na * h[i - 1]
        g_K = (n[i - 1] ** 4) * gbar_K
        g_M = gbar_M * p[i - 1]

        tau_V_inv = (g_Na + g_K + g_leak + g_M) / C
        V_inf = (
            g_Na * E_Na
            + g_K * E_K
            + g_leak * E_leak
            + g_M * E_K
            + I_inj[i - 1]
            + nois_fact * rng.randn() / (tstep**0.5)
        ) / (tau_V_inv * C)
        V[i] = V_inf + (V[i - 1] - V_inf) * np.exp(-tstep * tau_V_inv)

        # gates relax towards the steady state of the freshly updated voltage
        v_new = V[i]
        n_ss, n_tau = gate(alpha_n, beta_n, v_new)
        m_ss, m_tau = gate(alpha_m, beta_m, v_new)
        h_ss, h_tau = gate(alpha_h, beta_h, v_new)
        p_ss = p_inf(v_new)

        n[i] = n_ss + (n[i - 1] - n_ss) * np.exp(-tstep / n_tau)
        m[i] = m_ss + (m[i - 1] - m_ss) * np.exp(-tstep / m_tau)
        h[i] = h_ss + (h[i - 1] - h_ss) * np.exp(-tstep / h_tau)
        p[i] = p_ss + (p[i - 1] - p_ss) * np.exp(-tstep / tau_p(v_new))

    return np.array(V).reshape(-1, 1)


def calculate_summary_statistics(x):
    """Reduce a simulated voltage trace to a short vector of features.

    Parameters
    ----------
    x : dict
        Output of the simulator. The keys read here are ``"data"`` (the
        voltage trace), ``"time"`` (the matching time points) and ``"dt"``
        (the time step).

    Returns
    -------
    np.array, summary statistics
        In order: number of spikes inside the stimulus window, mean voltage
        before stimulus onset, standard deviation of the voltage over the
        last tenth of the pre-stimulus period, mean voltage inside the
        stimulus window, and the 2nd to 4th central moments of the voltage
        inside the window (the 3rd and 4th divided by the matching power of
        the standard deviation).
    """
    # Only the stimulus window of the default protocol is taken from here;
    # time grid and step come from the simulator output below.
    _, t_on, t_off, _, _, _ = syn_current()

    n_mom = 4
    n_summary = np.minimum(7, n_mom + 3)

    t = x["time"]
    dt = x["dt"]
    trace = x["data"]

    # Work on a copy so spike detection does not alter the raw trace.
    v = np.array(trace)

    # put everything to -10 that is below -10 or has negative slope
    v[np.where(v < -10)] = -10
    v[np.where(np.diff(v) < 0)] = -10

    # remaining negative slopes are at spike peaks
    peak_idx = np.where(np.diff(v) < 0)
    spike_times = np.array(t)[peak_idx]
    inside_stim = (spike_times > t_on) & (spike_times < t_off)
    spike_times_stim = spike_times[inside_stim]

    # number of spikes: drop detections within 0.5 time units of the
    # previous one
    if spike_times_stim.shape[0] > 0:
        gaps = np.append(1, np.diff(spike_times_stim))
        spike_times_stim = spike_times_stim[gaps > 0.5]

    # resting potential and std
    rest_pot = np.mean(trace[t < t_on])
    rest_pot_std = np.std(trace[int(0.9 * t_on / dt) : int(t_on / dt)])

    # moments
    stim_trace = trace[(t > t_on) & (t < t_off)]
    higher_orders = np.linspace(3, n_mom, n_mom - 2)
    all_orders = np.linspace(2, n_mom, n_mom - 1)
    # the 2nd moment is left unscaled, hence the leading one
    std_pw = np.concatenate((np.ones(1), np.power(np.std(stim_trace), higher_orders)))
    moments = spstats.moment(stim_trace, all_orders) / std_pw

    # concatenation of summary statistics
    features = np.concatenate((
        np.array([spike_times_stim.shape[0]]),
        np.array([rest_pot, rest_pot_std, np.mean(stim_trace)]),
        moments,
    ))

    return features[0:n_summary]


def calculate_summary_statistics_torch(x, t, dt):
    """Calculate summary statistics of a voltage trace held in a tensor.

    Torch counterpart of ``calculate_summary_statistics``. All statistics are
    computed on the raw trace with population (biased) estimators, so the
    two functions agree. Only spike detection uses a clipped working copy.

    Parameters
    ----------
    x : torch.Tensor
        Voltage trace (output of the simulator).
        NOTE(review): assumed 1-d with one sample per entry of ``t`` - confirm
        against callers.
    t : array-like or torch.Tensor or None
        Time points matching ``x``. ``None`` falls back to the time grid of
        the default ``syn_current()`` protocol.
    dt : float or None
        Time step. ``None`` falls back to the default ``syn_current()`` step.

    Returns
    -------
    torch.Tensor
        Detached 1-d tensor on ``x.device`` with, in order: spike count in the
        stimulus window, mean pre-stimulus voltage, standard deviation of the
        voltage over the last tenth of the pre-stimulus period, mean voltage
        in the stimulus window, and the 2nd to 4th central moments in the
        window (the 3rd and 4th divided by the matching power of the
        standard deviation).
    """
    # The stimulus window always comes from the default protocol. Its time
    # grid and step are only a fallback: unpacking them straight into ``t`` and
    # ``dt`` would silently discard what the caller passed in.
    _, t_on, t_off, default_dt, default_t, _ = syn_current()
    if t is None:
        t = default_t
    if dt is None:
        dt = default_dt

    n_mom = 4
    n_summary = min(7, n_mom + 3)

    # Statistics are not differentiated through, so work on a detached view.
    data = x.detach()
    t = torch.as_tensor(t, device=data.device)
    stim_mask = (t > t_on) & (t < t_off)

    # Clipped working copy used for spike detection only
    v = data.clone()

    # put everything to -10 that is below -10 or has negative slope
    v[torch.where(v < -10)] = -10
    v[torch.where(torch.diff(v) < 0)] = -10

    # remaining negative slopes are at spike peaks
    peak_idx = torch.where(torch.diff(v) < 0)[0]
    spike_times = t[peak_idx]
    spike_times_stim = spike_times[(spike_times > t_on) & (spike_times < t_off)]

    # number of spikes: drop detections within 0.5 time units of the
    # previous one
    if spike_times_stim.shape[0] > 0:
        first_gap = torch.ones(
            1, dtype=spike_times_stim.dtype, device=spike_times_stim.device
        )
        gaps = torch.cat((first_gap, torch.diff(spike_times_stim)))
        spike_times_stim = spike_times_stim[gaps > 0.5]

    # resting potential and std (population std, as np.std computes)
    rest_pot = torch.mean(data[t < t_on])
    rest_pot_std = torch.std(
        data[int(0.9 * t_on / dt) : int(t_on / dt)], unbiased=False
    )

    # moments of the raw trace inside the stimulus window
    stim_data = data[stim_mask]
    stim_mean = torch.mean(stim_data)
    stim_std = torch.std(stim_data, unbiased=False)
    centered = stim_data - stim_mean
    central = torch.stack(
        [torch.mean(centered**order) for order in range(2, n_mom + 1)]
    )
    # the 2nd moment is left unscaled, hence the leading one
    std_pw = torch.stack(
        [torch.ones_like(stim_std)]
        + [stim_std**order for order in range(3, n_mom + 1)]
    )
    moments = central / std_pw

    # concatenation of summary statistics
    spike_count = torch.tensor(
        [float(spike_times_stim.shape[0])], dtype=data.dtype, device=data.device
    )
    sum_stats_vec = torch.cat((
        spike_count,
        torch.stack((rest_pot, rest_pot_std, stim_mean)),
        moments,
    ))

    return sum_stats_vec[:n_summary]