import flax.linen as nn
import jax

from egxc.utils.typing import PRECISION


class ScaledSigmoid(nn.Module):
    """
    Sigmoid activation with a trainable scalar scale.

    Evaluates ``scale * sigmoid(x / scale) + constant_y_offset``.

    Attributes:
        initial_scale: Value the ``scale`` parameter is initialised to.
        constant_y_offset: Fixed shift added to the output; it is a plain
            module field, not a registered parameter.
    """

    initial_scale: float = 1.0
    constant_y_offset: float = 0.0

    def setup(self) -> None:
        """Register the scalar (shape ``()``) parameter named ``'scale'``."""
        scale_init = nn.initializers.constant(self.initial_scale)
        # NOTE(review): PRECISION.local_nn is presumably the dtype used for
        # this parameter -- confirm against egxc.utils.typing.
        self.scale = self.param('scale', scale_init, (), PRECISION.local_nn)

    def __call__(self, x: jax.Array) -> jax.Array:
        """Return ``scale * sigmoid(x / scale) + constant_y_offset``."""
        scale = self.scale
        squashed = nn.sigmoid(x / scale)
        return scale * squashed + self.constant_y_offset


class LinearSkip(nn.Module):
    """
    Dense layer whose result is added back onto its input.

    Evaluates ``x + Dense(output_dim)(x)``. With the default kernel
    initializer the dense weights start out small, so the layer as a whole
    begins close to the identity map.

    Attributes:
        output_dim: Number of output features of the dense layer.
        use_bias: Whether the dense layer carries a bias term.
        kernel_init: Initializer for the dense layer's kernel.

    NOTE(review): the addition presumably expects the last axis of ``x`` to
    equal ``output_dim``; otherwise it relies on broadcasting or fails --
    confirm against callers.
    """

    output_dim: int
    use_bias: bool = True
    kernel_init: nn.initializers.Initializer = nn.initializers.truncated_normal(
        lower=-2,
        upper=2,
        stddev=0.0625,  # small weights keep x + Dense(x) near identity by default
    )

    def setup(self):
        """Create the dense sub-layer that is applied alongside the skip path."""
        self.linear = nn.Dense(
            features=self.output_dim,
            use_bias=self.use_bias,
            kernel_init=self.kernel_init,
        )

    def __call__(self, x: jax.Array) -> jax.Array:
        """Return the input plus its dense transformation."""
        residual = self.linear(x)
        return x + residual
