import jax.numpy as jnp
from jax import custom_vjp
import flax.linen as nn


def torch_compatible_dense(in_features, out_features):
    """Build an ``nn.Dense`` layer initialised like ``torch.nn.Linear``.

    Both the kernel and the bias are drawn from the symmetric uniform
    distribution ``U(-1/sqrt(in_features), 1/sqrt(in_features))``.

    Args:
        in_features: Size of the input feature axis. It is only used here to
            compute the bias bound; a value ``<= 0`` gives a bound of 0.
        out_features: Number of output features of the layer.

    Returns:
        An ``nn.Dense`` module configured with the initialisers above.
    """
    # Local import: ``jax.random`` is only needed by this factory.
    import jax

    bound = 1 / jnp.sqrt(in_features) if in_features > 0 else 0

    def bias_init(key, shape, dtype=jnp.float32):
        # ``nn.initializers.uniform(scale=bound)`` samples from [0, bound),
        # which never yields negative biases. PyTorch uses the symmetric
        # interval, so sample it explicitly.
        return jax.random.uniform(key, shape, dtype, minval=-bound, maxval=bound)

    return nn.Dense(
        features=out_features,
        # variance = (1/3) / fan_in  ->  uniform limit = sqrt(3 * variance)
        #                                              = 1 / sqrt(fan_in)
        kernel_init=nn.initializers.variance_scaling(1.0 / 3.0, "fan_in", "uniform"),
        bias_init=bias_init,
    )


def ones_dense(in_features, out_features):
    """Build an ``nn.Dense`` layer with an all-ones kernel and a zero bias.

    Args:
        in_features: Accepted for signature parity with the other dense
            factories; it is not used by this function.
        out_features: Number of output features of the layer.

    Returns:
        An ``nn.Dense`` module with constant initialisers.
    """
    initializers = nn.initializers
    layer = nn.Dense(
        features=out_features,
        kernel_init=initializers.ones,
        bias_init=initializers.zeros,
    )
    return layer


def stable_norm(x, axis=-1, keepdims=False, epsilon=1e-12):
    """Return ``sqrt(sum(x**2) + epsilon)`` reduced over ``axis``.

    ``epsilon`` is added inside the square root, so for real ``x`` the
    argument of ``sqrt`` never reaches exactly zero.

    Args:
        x: Array whose squared entries are summed.
        axis: Axis (or axes) to reduce over.
        keepdims: Whether the reduced axes are kept with size one.
        epsilon: Small constant added to the sum of squares.

    Returns:
        The regularised norm along ``axis``.
    """
    sum_of_squares = jnp.sum(x * x, axis=axis, keepdims=keepdims)
    return jnp.sqrt(sum_of_squares + epsilon)


def safe_reciprocal(x, epsilon=1e-12):
    """Return a sign-preserving reciprocal of ``x`` that never divides by zero.

    The magnitude is ``1 / (|x| + epsilon)``. Negative inputs give a negative
    result; zero and positive inputs give a positive one, so ``x == 0`` maps
    to ``1 / epsilon``.

    Args:
        x: Scalar or array to invert.
        epsilon: Small constant added to ``|x|`` in the denominator.

    Returns:
        The regularised reciprocal, carrying the sign of ``x``.
    """
    magnitude = 1.0 / (jnp.abs(x) + epsilon)
    # Restore the sign that ``abs`` removed. ``where`` is used rather than
    # ``sign`` because ``sign(0) == 0`` would collapse the result at zero.
    return jnp.where(x < 0, -magnitude, magnitude)


def safe_div(x, y, epsilon=1e-12):
    """Divide ``x`` by ``y`` without ever dividing by exactly zero.

    The quotient is taken against ``|y| + epsilon`` and then multiplied by
    ``sign(y)``. Because ``sign(0) == 0``, the result is 0 wherever ``y`` is
    zero.

    Args:
        x: Numerator.
        y: Denominator.
        epsilon: Small constant added to ``|y|``.

    Returns:
        The regularised quotient.
    """
    denominator = jnp.abs(y) + epsilon
    unsigned_quotient = x / denominator
    return unsigned_quotient * jnp.sign(y)


def stable_softmax(x, axis=-1):
    """Compute softmax along ``axis`` using the max-subtraction trick.

    The maximum along ``axis`` is subtracted before exponentiating, so the
    largest exponent is 0 and ``exp`` cannot overflow for finite inputs.

    Args:
        x: Input array of logits.
        axis: Axis along which the outputs are normalised.

    Returns:
        An array shaped like ``x`` whose entries along ``axis`` are the
        exponentials divided by their sum.
    """
    peak = jnp.max(x, axis=axis, keepdims=True)
    numerators = jnp.exp(x - peak)
    normaliser = jnp.sum(numerators, axis=axis, keepdims=True)
    return numerators / normaliser


def safe_power(x, p, epsilon=1e-12):
    """Raise ``|x| + epsilon`` to the power ``p`` and reapply the sign of ``x``.

    Because the result is multiplied by ``sign(x)``, it is 0 wherever ``x``
    is zero, whatever the value of ``p``.

    Args:
        x: Base value(s); may be negative.
        p: Exponent.
        epsilon: Small constant added to ``|x|`` before exponentiation.

    Returns:
        ``sign(x) * (|x| + epsilon) ** p``.
    """
    magnitude = jnp.abs(x) + epsilon
    return jnp.sign(x) * (magnitude**p)
