from functools import partial
import jax
import jax.numpy as jnp
from flax import linen as nn

# Short alias for JAX's associative (parallel prefix) scan. The modules in
# this file use it to evaluate their linear recurrences over a whole
# sequence in one call instead of stepping through it element by element.
parallel_scan = jax.lax.associative_scan


# Parallel scan operations
@jax.vmap
def binary_operator_diag(q_i, q_j):
    """Compose two steps of a diagonal linear recurrence.

    Each element is a pair ``(A, b)`` standing for the affine map
    ``x -> A * x + b``. Applying ``q_i`` first and ``q_j`` second yields
    ``x -> (A_j * A_i) * x + (A_j * b_i + b_j)``, which is what is returned.
    The function is vmapped over the leading axis of its inputs so that it
    can be used as the combine function of an associative scan.

    Args:
        q_i: ``(A_i, b_i)`` pair of the earlier element.
        q_j: ``(A_j, b_j)`` pair of the later element.

    Returns:
        The ``(A, b)`` pair of the composed map.
    """
    (decay_prev, drive_prev), (decay_next, drive_next) = q_i, q_j
    combined_decay = decay_next * decay_prev
    combined_drive = decay_next * drive_prev + drive_next
    return combined_decay, combined_drive


@jax.custom_gradient
def ATan(mem):
    """
    Heaviside step function with a surrogate gradient.

    Forward pass: returns 1.0 where ``mem`` is strictly positive and 0.0
    elsewhere, as float32. Backward pass: the incoming cotangent is scaled
    by ``1 / (1 + (pi * mem) ** 2)`` instead of the (zero almost everywhere)
    true derivative of the step.
    """
    spike = (mem > jnp.zeros_like(mem)).astype(jnp.float32)
    surrogate = 1 / (1 + jnp.power(jnp.pi * mem, 2))

    def backward(g):
        # custom_gradient expects one cotangent per primal input.
        return (g * surrogate,)

    return spike, backward


def beta_init(key, shape, tau, dtype=jnp.float32):
    """Initialise every entry to ``logit(tau)``.

    Passing the result through a sigmoid therefore gives back ``tau``.

    Args:
        key: PRNG key; accepted so the function has an initializer-style
            signature, but not used (the initialisation is deterministic).
        shape: shape of the array to create.
        tau: value in (0, 1) whose logit fills the array.
        dtype: dtype of the returned array.

    Returns:
        Array of the given shape and dtype filled with ``logit(tau)``.
    """
    filled = jnp.ones(shape=shape, dtype=dtype) * tau
    return jax.scipy.special.logit(filled)


class ParallelLIF(nn.Module):
    """Leaky integrate-and-fire layer with a learnable per-channel decay.

    The membrane follows ``mem_t = tau * mem_{t-1} + (1 - tau) * u_t`` with
    ``tau = sigmoid(beta)``. Spikes are produced by thresholding the membrane
    with a surrogate-gradient step function.

    Attributes:
        d_hidden: number of channels (length of the learnable ``beta`` vector).
        tau: initial decay; ``beta`` is initialised to ``logit(tau)``.
        th: firing threshold.
    """
    d_hidden: int
    tau: float = 0.5
    th: float = 0.5

    def setup(self):
        # Fallback key, used only when the caller supplies no "dropout" RNG
        # stream. It is constant, so the noise drawn from it is identical on
        # every call.
        self.key = jax.random.key(0)
        self.beta = self.param("beta", beta_init, (self.d_hidden,), self.tau)
        self.step_func = ATan

    def __call__(self, u, mem=None):
        """Update the membrane and emit spikes.

        Args:
            u: input current. When ``mem`` is None the recurrence is scanned
                along axis 0 of ``u``.
            mem: previous membrane for a single sequential step, or None to
                process the whole of ``u`` with a parallel scan starting from
                a zero membrane.

        Returns:
            ``(spk, mem)``: the spikes and the updated membrane.
        """
        tau = jax.nn.sigmoid(self.beta)
        if mem is None:
            # One copy of the decay per scanned position, as the scan expects
            # its operands to share the leading axis.
            tau = jnp.repeat(tau[None, ...], u.shape[0], axis=0)
            _, mem = parallel_scan(binary_operator_diag, (tau, u * (1 - tau)))
        else:
            mem = tau * mem + u * (1 - tau)

        # Vth is stochastic during parallel training, and deterministic during sequential rollout
        if mem.shape[0] > 1:
            # Draw fresh noise from the "dropout" RNG stream when one is
            # available, so the threshold jitter differs between calls (and
            # between batch elements when that stream is split). Without it,
            # fall back to the fixed key, i.e. the previous behaviour.
            if self.has_rng("dropout"):
                noise_key = self.make_rng("dropout")
            else:
                noise_key = self.key
            spk = self.step_func(mem - self.th + 0.5 - jax.random.uniform(noise_key, u.shape))
        else:
            spk = self.step_func(mem - self.th)

        return spk, mem


class LIFGateUnit(nn.Module):
    """Complex-valued linear recurrence gated by two groups of LIF spikes.

    The carried state is laid out along the last axis as
    ``[gate membrane (2 * hidden_size) | Re(state) | Im(state)]``.

    Attributes:
        hidden_size: width of the complex recurrent state.
        output_size: width of the real-valued output.
    """

    hidden_size: int
    output_size: int

    def setup(self):
        gate_width = self.hidden_size * 2
        self.gate_func = ParallelLIF(gate_width)
        self.linear_gate = nn.Dense(gate_width, use_bias=False)
        self.linear_in = nn.Dense(gate_width, dtype=jnp.complex64, use_bias=False)
        self.linear_out = nn.Dense(self.output_size, use_bias=False)
        self.layer_norm = nn.LayerNorm(use_bias=False, use_scale=False)

    def __call__(self, mem, x):
        """Run the unit on ``x``.

        Args:
            mem: carried state in the layout described on the class, or None
                to start from scratch and process ``x`` with parallel scans.
            x: input features.

        Returns:
            ``(mem, z)``: the state at the last position along axis 0 (that
            axis is kept, with length 1) and the layer-normalised output.
        """
        h = self.hidden_size

        # Unpack the carried state, if any, into gate membrane and complex state.
        gate_mem = None
        state = None
        if mem is not None:
            gate_mem, state_re, state_im = jnp.split(mem, [2 * h, 3 * h], axis=-1)
            state = state_re + 1j * state_im

        # Spiking gates: s1 controls the state update, s2 the output mix.
        spikes, gate_mem = self.gate_func(self.linear_gate(x), gate_mem)
        s1, s2 = jnp.split(spikes, [h], axis=-1)

        # Complex projection of the input: u1 drives the decay, u2 the input.
        projected = self.linear_in(x.astype(jnp.complex64))
        u1, u2 = jnp.split(projected, [h], axis=-1)

        # Rescale u1 so that |c| < 1, which keeps the recurrence from blowing up.
        mod = jnp.sqrt(u1.real ** 2 + u1.imag ** 2 + 1)
        c = u1 / mod * jnp.tanh(mod)

        # With s1 == 0 the decay is 1 and the drive is 0: the state is carried
        # over unchanged. With s1 == 1 the state decays by c and takes in u2.
        decay = c * s1 + (1 - s1)
        drive = u2 * s1

        if state is None:
            _, state = parallel_scan(binary_operator_diag, (decay, drive))
        else:
            state = decay * state + drive

        # s2 selects between the recurrent state and the raw projected input.
        mixed = state * s2 + u2 * (1 - s2)
        features = jnp.concatenate([mixed.real, mixed.imag], axis=-1)
        z = self.layer_norm(self.linear_out(features))

        new_mem = jnp.concatenate([gate_mem, state.real, state.imag], axis=-1)
        return new_mem[-1:], z


class StackedLIFGate(nn.Module):
    """Stack of ``LIFGateUnit`` layers applied one after another.

    Attributes:
        hidden_size: recurrent width passed to every layer.
        output_size: output width passed to every layer.
        num_layers: number of stacked layers.
    """
    hidden_size: int
    output_size: int
    num_layers: int = 1

    def setup(self):
        self.rnns = [LIFGateUnit(self.hidden_size, self.output_size) for _ in range(self.num_layers)]

    def __call__(self, mem, x):
        """Feed ``x`` through all layers.

        Args:
            mem: per-layer carried states, indexable by layer number. An entry
                may be None to start that layer from scratch; passing None for
                the whole argument does so for every layer.
            x: input to the first layer.

        Returns:
            ``(hidden_states, z)``: the list of new per-layer states and the
            output of the last layer.
        """
        if mem is None:
            # Each layer accepts None as "no previous state"; expand it so the
            # per-layer indexing below works.
            mem = [None] * self.num_layers

        hidden_states = []
        z = x
        for i, layer in enumerate(self.rnns):
            h, z = layer(mem[i], z)
            hidden_states.append(h)
        return hidden_states, z


# Here we call vmap to parallelize across a batch of input sequences.
# - in_axes=0 / out_axes=0: inputs and outputs are mapped over their leading axis.
# - variable_axes: entries set to None are shared across the batch, while
#   "cache" is mapped along axis 0.
# - split_rngs: the "params" RNG stream is not split across the batch, the
#   "dropout" stream is.
BatchStackedEncoderModel = nn.vmap(
    StackedLIFGate,
    in_axes=0,
    out_axes=0,
    variable_axes={"params": None, "dropout": None, "batch_stats": None, "cache": 0, "prime": None},
    split_rngs={"params": False, "dropout": True},
    axis_name="batch",
)