import math

import tensorflow as tf


class HodgkinHuxleyODE:
    """Single-compartment Hodgkin-Huxley-type neuron integrated with forward Euler.

    The membrane potential ``V`` is driven by four channel currents and a
    stimulus:

    * a sodium current gated by ``m**3 * h``;
    * a potassium current gated by ``n**4``;
    * a slow potassium current gated by ``p``, which shares ``E_K``;
    * a leak current;
    * a rectangular step current that is on for ``I_on < t < I_off``.

    Fixed model constants are stored on the instance as float32 tensors. The
    seven per-simulation parameters are passed to :meth:`solve_ode` as rows
    ``(g_Na, g_K, g_M, C, E_leak, E_Na, E_K)``.

    NOTE(review): units are not stated in this file. The defaults look like
    mV for voltages and ms for times; confirm before relying on them.
    """

    def __init__(
        self,
        g_leak: float = 0.1,
        V0: float = -70.0,
        Vt: float = -60.0,
        tau_max: float = 6e2,
        I_on: float = 10.0,
        I_off: float = 50.0,
        curr_level: float = 5e-4,
    ):
        """Store the fixed model constants.

        Args:
            g_leak: leak conductance.
            V0: initial membrane potential used by :meth:`solve_ode`.
            Vt: voltage offset subtracted inside every alpha/beta rate function.
            tau_max: scale of the ``p`` gate time constant (see :meth:`tau_p`).
            I_on: stimulus start time (exclusive).
            I_off: stimulus end time (exclusive).
            curr_level: stimulus amplitude before division by the soma area.
        """
        # tf.cast (rather than a bare tf.constant) forces float32 so that int
        # arguments such as I_on=10, or numpy float64 scalars, do not create
        # int32/float64 tensors that clash with the float32 state later.
        self.g_leak = tf.cast(g_leak, tf.float32)
        self.V0 = tf.cast(V0, tf.float32)
        self.Vt = tf.cast(Vt, tf.float32)
        self.tau_max = tf.cast(tau_max, tf.float32)
        self.I_on = tf.cast(I_on, tf.float32)
        self.I_off = tf.cast(I_off, tf.float32)
        self.curr_level = tf.cast(curr_level, tf.float32)
        # Area of a disc of radius 70e-4 (presumably cm, i.e. 70 um; confirm).
        # The stimulus amplitude is divided by this area in I_in.
        self.A_soma = tf.cast(math.pi * (70.0 * 1e-4) ** 2, tf.float32)

    def efun(self, z: tf.Tensor) -> tf.Tensor:
        """Return ``z / (exp(z) - 1)``, finite through ``z = 0``.

        For ``|z| < 1e-4`` the first-order expansion ``1 - z / 2`` is used.
        The division branch is evaluated on a substituted input (1.0) inside
        that region. Otherwise the discarded ``0 / 0`` would still produce a
        NaN that ``tf.where`` propagates into gradients.
        """
        z = tf.convert_to_tensor(z)
        near_zero = tf.abs(z) < 1e-4
        safe_z = tf.where(near_zero, tf.ones_like(z), z)
        # expm1 avoids the cancellation in exp(z) - 1 for small |z|.
        ratio = safe_z / tf.math.expm1(safe_z)
        return tf.where(near_zero, 1.0 - z / 2.0, ratio)

    def I_in(self, t: float) -> tf.Tensor:
        """Return the stimulus current at time ``t``.

        The value is ``curr_level / A_soma`` strictly inside ``(I_on, I_off)``
        and 0 otherwise. ``t`` may be a Python number or a tensor of any
        shape; the result has the shape of ``t``.
        """
        t = tf.cast(t, tf.float32)
        stimulus_on = tf.logical_and(t > self.I_on, t < self.I_off)
        amplitude = self.curr_level / self.A_soma
        return tf.where(stimulus_on, amplitude, tf.zeros_like(amplitude))

    # Rate functions. For each gate x in {m, h, n},
    #   x_inf = alpha / (alpha + beta) and tau = 1 / (alpha + beta),
    # so alpha and beta are the forward and backward rates of first-order
    # gate kinetics. All rates are elementwise in the voltage tensor x.

    def alpha_m(self, x: tf.Tensor) -> tf.Tensor:
        """Forward rate of the ``m`` gate."""
        v1 = x - self.Vt - 13.0
        return 0.32 * self.efun(-0.25 * v1) / 0.25

    def beta_m(self, x: tf.Tensor) -> tf.Tensor:
        """Backward rate of the ``m`` gate."""
        v1 = x - self.Vt - 40.0
        return 0.28 * self.efun(0.2 * v1) / 0.2

    def alpha_h(self, x: tf.Tensor) -> tf.Tensor:
        """Forward rate of the ``h`` gate."""
        v1 = x - self.Vt - 17.0
        return 0.128 * tf.exp(-v1 / 18.0)

    def beta_h(self, x: tf.Tensor) -> tf.Tensor:
        """Backward rate of the ``h`` gate."""
        v1 = x - self.Vt - 40.0
        return 4.0 / (1.0 + tf.exp(-0.2 * v1))

    def alpha_n(self, x: tf.Tensor) -> tf.Tensor:
        """Forward rate of the ``n`` gate."""
        v1 = x - self.Vt - 15.0
        return 0.032 * self.efun(-0.2 * v1) / 0.2

    def beta_n(self, x: tf.Tensor) -> tf.Tensor:
        """Backward rate of the ``n`` gate."""
        v1 = x - self.Vt - 10.0
        return 0.5 * tf.exp(-v1 / 40.0)

    def tau_n(self, x: tf.Tensor) -> tf.Tensor:
        """Time constant of the ``n`` gate."""
        return 1.0 / (self.alpha_n(x) + self.beta_n(x))

    def n_inf(self, x: tf.Tensor) -> tf.Tensor:
        """Steady-state value of the ``n`` gate."""
        alpha = self.alpha_n(x)
        return alpha / (alpha + self.beta_n(x))

    def tau_m(self, x: tf.Tensor) -> tf.Tensor:
        """Time constant of the ``m`` gate."""
        return 1.0 / (self.alpha_m(x) + self.beta_m(x))

    def m_inf(self, x: tf.Tensor) -> tf.Tensor:
        """Steady-state value of the ``m`` gate."""
        alpha = self.alpha_m(x)
        return alpha / (alpha + self.beta_m(x))

    def tau_h(self, x: tf.Tensor) -> tf.Tensor:
        """Time constant of the ``h`` gate."""
        return 1.0 / (self.alpha_h(x) + self.beta_h(x))

    def h_inf(self, x: tf.Tensor) -> tf.Tensor:
        """Steady-state value of the ``h`` gate."""
        alpha = self.alpha_h(x)
        return alpha / (alpha + self.beta_h(x))

    def p_inf(self, x: tf.Tensor) -> tf.Tensor:
        """Steady-state value of the ``p`` gate: a sigmoid centred at -35."""
        v1 = x + 35.0
        return 1.0 / (1.0 + tf.exp(-0.1 * v1))

    def tau_p(self, x: tf.Tensor) -> tf.Tensor:
        """Time constant of the ``p`` gate, scaled by ``tau_max``."""
        v1 = x + 35.0
        return self.tau_max / (3.3 * tf.exp(0.05 * v1) + tf.exp(-0.05 * v1))

    def ode_step(
        self,
        state: tf.Tensor,
        params: tf.Tensor,
        t: float,
        dt: float,
    ):
        """Advance ``state`` by one forward-Euler step of size ``dt``.

        Args:
            state: ``(..., 5)`` tensor holding ``(V, n, m, h, p)`` on the last axis.
            params: ``(..., 7)`` tensor holding
                ``(g_Na, g_K, g_M, C, E_leak, E_Na, E_K)`` on the last axis.
            t: time at the start of the step; only used for the stimulus.
            dt: step size.

        Returns:
            ``(..., 5)`` tensor with the updated ``(V, n, m, h, p)``.
        """
        V, n, m, h, p = tf.unstack(state, axis=-1)
        g_Na, g_K, g_M, C, E_leak, E_Na, E_K = tf.unstack(params, axis=-1)

        # Total membrane current: Na, K, slow K (shares E_K), leak, stimulus.
        membrane_current = (
            (m**3) * g_Na * h * (E_Na - V)
            + (n**4) * g_K * (E_K - V)
            + g_M * p * (E_K - V)
            + self.g_leak * (E_leak - V)
            + self.I_in(t)
        )

        # C * dV/dt = total current, so divide by the capacitance (the
        # previous version multiplied by C).
        dV = membrane_current / C
        dn = (self.n_inf(V) - n) / self.tau_n(V)
        dm = (self.m_inf(V) - m) / self.tau_m(V)
        dh = (self.h_inf(V) - h) / self.tau_h(V)
        dp = (self.p_inf(V) - p) / self.tau_p(V)

        derivatives = tf.stack([dV, dn, dm, dh, dp], axis=-1)
        return state + derivatives * dt

    @tf.function(jit_compile=True)
    def solve_ode(
        self,
        theta: tf.Tensor,
        t_final: float = 60.0,
        dt: float = 0.01,
        subsample: int = 30,
    ):
        """Integrate the model once per parameter row in ``theta``.

        Args:
            theta: tensor reshapeable to ``(batch, 7)``. Each row is
                ``(g_Na, g_K, g_M, C, E_leak, E_Na, E_K)``; values are cast
                to float32.
            t_final: simulated duration. ``ceil(t_final / dt)`` Euler steps
                are taken.
            dt: integration step.
            subsample: keep every ``subsample``-th recorded voltage. The
                default of 30 matches the previously hard-coded stride.

        Returns:
            ``(batch, ceil(num_steps / subsample))`` float32 tensor. Column
            ``k`` is V after Euler step ``k * subsample + 1``. The initial
            potential ``V0`` itself is not recorded.
        """
        num_steps = tf.cast(tf.math.ceil(t_final / dt), tf.int32)

        theta = tf.reshape(tf.cast(theta, tf.float32), (-1, 7))
        batch_size = tf.shape(theta)[0]

        # Start every simulation at V0 with all gates at their V0 steady state.
        V_start = tf.fill([batch_size], self.V0)
        initial_state = tf.stack(
            [
                V_start,
                self.n_inf(V_start),
                self.m_inf(V_start),
                self.h_inf(V_start),
                self.p_inf(V_start),
            ],
            axis=-1,
        )

        # Row i holds V after step i, for every batch element.
        V_trace = tf.zeros((num_steps, batch_size), dtype=tf.float32)

        def condition(step, _state, _trace):
            return step < num_steps

        def body(step, state, trace):
            # Time at the start of this step. Use an explicit cast instead of
            # relying on AutoGraph to rewrite float() of a symbolic tensor.
            t = tf.cast(step, tf.float32) * dt
            state = self.ode_step(state, theta, t, dt)
            trace = tf.tensor_scatter_nd_update(trace, [step], state[:, 0])
            return step + 1, state, trace

        _, _, V_trace = tf.while_loop(
            condition, body, [tf.constant(0), initial_state, V_trace]
        )

        # (num_steps, batch) -> (batch, num_steps), then subsample in time.
        return tf.transpose(V_trace, perm=[1, 0])[:, ::subsample]
