import jax
import jax.numpy as jnp
from flax import nnx


class SiLuFFN(nnx.Module):
    """ Feed-forward layer with a Sigmoid Linear Unit (SiLU) activation.

    Computes ``silu(x @ w_in) @ w_out`` where ``silu(h) = h * sigmoid(h)``.
    input shape: (batch, context_len, feature)
    output shape: (batch, context_len, feature)

    Args:
        feature: size of the input/output (last) axis.
        ffn_feature: size of the hidden projection.
        init_scalar: ``scale`` handed to a fan-in, normal variance-scaling initializer.
        key: PRNG key; split into one sub-key per weight matrix.
        dtype: dtype of the created weights.
    """
    def __init__(self, feature: int, ffn_feature: int, init_scalar: float, key, dtype=jnp.float32):
        super().__init__()
        init_fn = jax.nn.initializers.variance_scaling(init_scalar, 'fan_in', 'normal')
        key_in, key_out = jax.random.split(key, num=2)

        # Projection feature -> ffn_feature, then back ffn_feature -> feature.
        self.w_in = nnx.Param(init_fn(key_in, shape=(feature, ffn_feature), dtype=dtype), name='FFN_W_IN')
        self.w_out = nnx.Param(init_fn(key_out, shape=(ffn_feature, feature), dtype=dtype), name='FFN_W_OUT')

    def __call__(self, x):
        # The matmul contracts x's last axis, which must have size `feature`.
        hidden = x @ self.w_in.value
        activated = hidden * jax.nn.sigmoid(hidden)
        return activated @ self.w_out.value


class GeGluFFN(nnx.Module):
    """ Gated feed-forward layer with a GELU gate (GeGLU).

    Computes ``((beta * gelu(x @ w_gelu)) * (x @ w_direct)) @ w_outside``.
    Note that ``beta`` multiplies the GELU branch *after* the activation.
    The call uses ``jax.nn.gelu``'s default ``approximate`` setting.
    input shape: (batch, context_len, feature)
    output shape: (batch, context_len, feature)

    Args:
        feature: size of the input/output (last) axis.
        ffn_feature: size of the hidden projection.
        init_scalar: ``scale`` handed to a fan-in, normal variance-scaling initializer.
        key: PRNG key; split into one sub-key per weight matrix.
        beta: scalar multiplier applied to the activated GELU branch.
        dtype: dtype of the created weights.
    """
    def __init__(self, feature: int, ffn_feature: int, init_scalar: float, key, beta=1.0, dtype=jnp.float32):
        super().__init__()
        init_fn = jax.nn.initializers.variance_scaling(init_scalar, 'fan_in', 'normal')
        key_gelu, key_direct, key_out = jax.random.split(key, num=3)
        up_shape = (feature, ffn_feature)
        down_shape = (ffn_feature, feature)

        self.beta = beta
        # Gate branch, value branch, and output projection.
        self.w_gelu = nnx.Param(init_fn(key_gelu, shape=up_shape, dtype=dtype), name='FFN_W_GELU')
        self.w_direct = nnx.Param(init_fn(key_direct, shape=up_shape, dtype=dtype), name='FFN_W_DIR')
        self.w_outside = nnx.Param(init_fn(key_out, shape=down_shape, dtype=dtype), name='FFN_W_OUT')

    def __call__(self, x):
        gate = self.beta * jax.nn.gelu(x @ self.w_gelu.value)
        value = x @ self.w_direct.value
        mixed = gate * value
        return mixed @ self.w_outside.value


class SwiGluFFN(nnx.Module):
    """ Gated feed-forward layer with a Swish gate (SwiGLU).

    Computes ``(swish_beta(x @ w_swish) * (x @ w_direct)) @ w_outside`` where
    ``swish_beta(z) = z * sigmoid(beta * z)``. With ``beta == 1`` this is SiLU.
    input shape: (batch, context_len, feature)
    output shape: (batch, context_len, feature)

    Args:
        feature: size of the input/output (last) axis.
        ffn_feature: size of the hidden projection.
        init_scalar: ``scale`` handed to a fan-in, normal variance-scaling initializer.
        key: PRNG key; split into one sub-key per weight matrix.
        beta: slope of the sigmoid inside the Swish gate.
        dtype: dtype of the created weights.
    """
    def __init__(self, feature: int, ffn_feature: int, init_scalar: float, key, beta=1.0, dtype=jnp.float32):
        super().__init__()
        keys = jax.random.split(key, num=3)
        initializer = jax.nn.initializers.variance_scaling(init_scalar, 'fan_in', 'normal')

        self.beta = beta
        self.w_swish = nnx.Param(initializer(keys[0], shape=(feature, ffn_feature), dtype=dtype), name='FFN_W_SWISH')
        self.w_direct = nnx.Param(initializer(keys[1], shape=(feature, ffn_feature), dtype=dtype), name='FFN_W_DIR')
        self.w_outside = nnx.Param(initializer(keys[2], shape=(ffn_feature, feature), dtype=dtype), name='FFN_W_OUT')

    def __call__(self, x):
        i = x @ self.w_swish.value
        # Swish_beta(i) = i * sigmoid(beta * i). The previous form
        # `swish(beta * i)` added an extra factor of beta to the gate.
        # Output is unchanged for the default beta == 1.
        s = i * jax.nn.sigmoid(self.beta * i)
        v = x @ self.w_direct.value
        return (s * v) @ self.w_outside.value
