import jax
import jax.numpy as jnp
import flax.linen as nn
from functools import partial
from typing import Callable, List, Tuple, Union, Dict, Any

from equiv_eikonal.utils import torch_compatible_dense, ones_dense

# ----------------------------
# Consistent initialization definitions.
# ----------------------------


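# Registry of elementwise activation functions, keyed by the names accepted in
# activation specs. Composite lambdas are wrapped in jax.jit; the remaining
# entries are plain JAX functions.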
ACTS = {
    "tanh": jnp.tanh,
    "atan": jnp.arctan,
    "sigmoid": jax.nn.sigmoid,
    "softplus": jax.nn.softplus,
    "relu": jax.nn.relu,
    "exp": jnp.exp,
    "elu": jax.nn.elu,
    "gelu": jax.nn.gelu,
    "sin": jnp.sin,
    # Unnormalised sinc, sin(z) / z, with value 1 at z == 0. ``jnp.sinc`` is
    # the normalised sin(pi*t) / (pi*t), so t = z / pi gives sin(z) / z.
    # A hand-written ``jnp.where(z == 0, 1, sin(z) / z)`` still evaluates the
    # 0/0 branch at z == 0, which turns the gradient there into NaN;
    # ``jnp.sinc`` keeps derivatives finite at the origin.
    "sinc": jax.jit(lambda z: jnp.sinc(z / jnp.pi)),
    "linear": lambda z: z,
    "abs_linear": jnp.abs,
    # Gaussian bump exp(-z^2).
    "gauss": jax.jit(lambda z: jnp.exp(-(z**2))),
    # z * sigmoid(z).
    "swish": jax.jit(lambda z: z * jax.nn.sigmoid(z)),
    # Laplacian bump exp(-|z|).
    "laplace": jax.jit(lambda z: jnp.exp(-jnp.abs(z))),
    # Sum of the Gaussian and Laplacian bumps.
    "gauslace": jax.jit(lambda z: jnp.exp(-(z**2)) + jnp.exp(-jnp.abs(z))),
}


class AdaptiveActivation(nn.Module):
    """Scaled activation ``act(n * x)``, optionally with a trainable slope.

    When ``adapt`` is true the layer owns a learnable parameter ``a`` of
    shape ``(1,)`` initialised to ones and computes ``act(n * a * x)``;
    otherwise it has no parameters and computes ``act(n * x)``.

    Attributes:
        act_name: Name of the activation (stored for reference; not read
            by the computation).
        adapt: Whether to create and apply the trainable parameter ``a``.
        n: Fixed scale factor applied to the input.
        act: Elementwise activation function.
    """

    act_name: str
    adapt: bool
    n: float
    act: Callable

    @nn.compact
    def __call__(self, x):
        """Apply the (optionally adaptive) scaled activation to ``x``."""
        # The parameter is declared inline in the compact method instead of
        # in a separate ``setup``, so the module uses one Flax definition
        # style. Name, shape and initializer are unchanged ("a", (1,), ones).
        scale = self.n
        if self.adapt:
            scale = scale * self.param("a", nn.initializers.ones, (1,))
        return self.act(scale * x)


class ActivationFactory:
    """Turn an activation spec (string or callable) into an activation callable.

    Supported specs:
      * a callable, which is returned unchanged;
      * ``"<name>"`` (no dash), which returns the plain function ``ACTS[name]``;
      * ``"<prefix>-<name>-<n>"``, which returns a function that wraps
        ``ACTS[name]`` in an ``AdaptiveActivation`` with scale ``n``. The
        activation is trainable only when ``prefix`` is exactly ``"ad"``.
    """

    @staticmethod
    def create_activation(act_spec):
        """Resolve ``act_spec`` to an activation callable.

        Args:
            act_spec: A callable, or a string spec as described on the class.

        Returns:
            The callable itself, a function from ``ACTS``, or a function
            ``x -> AdaptiveActivation(...)(x)`` that builds a new
            ``AdaptiveActivation`` module every time it is applied.

        Raises:
            ValueError: If ``act_spec`` is neither a str nor a callable, if a
                dashed spec does not have exactly three dash-separated parts,
                or if the activation name is not a key of ``ACTS``.
        """
        if callable(act_spec):
            return act_spec
        if not isinstance(act_spec, str):
            raise ValueError("'act' must be either a 'str' or a 'callable'")

        fields = act_spec.split("-")

        # No dash at all: a plain, non-adaptive lookup.
        if len(fields) == 1:
            plain = ACTS.get(act_spec)
            if plain is None:
                raise ValueError(f"Unsupported activation: {act_spec}")
            return plain

        if len(fields) != 3:
            raise ValueError(
                "Adaptive activation format should be '(ad)-activation_name-n'"
            )

        prefix, act_name, scale_text = fields
        if act_name not in ACTS:
            raise ValueError(f"Unsupported activation: {act_name}")

        # An unparseable scale silently falls back to 1.0.
        try:
            n = float(scale_text)
        except ValueError:
            n = 1.0

        trainable = prefix == "ad"
        base_fn = ACTS[act_name]

        def build_and_apply(x):
            # Constructed on each application so that, inside a Flax compact
            # method, every use becomes its own submodule.
            layer = AdaptiveActivation(
                act_name=act_name, adapt=trainable, n=n, act=base_fn
            )
            return layer(x)

        return build_and_apply


# Function for backward compatibility
def Activation(act):
    """Resolve an activation spec; retained for backward compatibility.

    Delegates to ``ActivationFactory.create_activation`` and returns its
    result unchanged; any ``ValueError`` it raises propagates to the caller.
    """
    resolve = ActivationFactory.create_activation
    resolved = resolve(act)
    return resolved


class DenseBody(nn.Module):
    """Fully connected network: ``nl`` hidden layers followed by an output layer.

    Each hidden layer is a dense layer followed by the ``act`` activation.
    The output layer is a dense layer with ``out_dim`` features followed by
    ``out_act``.

    Attributes:
        input_dim: Number of input features (passed to the first layer).
        nu: Hidden width(s). An int is used for every hidden layer; a list or
            tuple gives one width per hidden layer and must have ``nl`` entries.
        nl: Number of hidden layers, counting the input layer; must be >= 1.
        out_dim: Number of output features.
        act: Activation spec or callable applied after every hidden layer
            (resolved with ``Activation``).
        out_act: Activation spec or callable applied after the output layer.
    """

    input_dim: int
    nu: Union[int, List[int]]
    nl: int
    out_dim: int = 1
    act: Union[str, Callable] = "ad-gauss-1"
    out_act: Union[str, Callable] = "linear"

    @nn.compact
    def __call__(self, x):
        """Run the network on ``x`` and return the output-layer activations.

        Raises:
            ValueError: If ``nu`` is neither an int nor a list/tuple of length
                ``nl``, or if ``nl < 1`` (no hidden layer to build).
        """
        # Resolve per-layer widths. Validation raises instead of using
        # ``assert`` so it still runs under ``python -O``.
        if isinstance(self.nu, int):
            hidden_dims = [self.nu] * self.nl
        elif isinstance(self.nu, (list, tuple)) and len(self.nu) == self.nl:
            hidden_dims = list(self.nu)
        else:
            raise ValueError(
                "Number of hidden layers 'nl' must match the length of 'nu'"
            )
        if not hidden_dims:
            # The input layer needs a width; without this check the lookup of
            # hidden_dims[0] below would fail with a bare IndexError.
            raise ValueError(
                "DenseBody needs at least one hidden layer ('nl' must be >= 1)"
            )

        # Kaiming / He normal initialization for the nn.Dense layers.
        kernel_init = nn.initializers.variance_scaling(2.0, "fan_in", "normal")
        bias_init = nn.initializers.zeros

        # Input layer. NOTE(review): torch_compatible_dense comes from
        # equiv_eikonal.utils and does not receive kernel_init/bias_init;
        # presumably it mirrors torch.nn.Linear's own initialization -- confirm.
        x = torch_compatible_dense(
            in_features=self.input_dim, out_features=hidden_dims[0]
        )(x)
        # Activation(...) is called per layer; for adaptive specs the returned
        # callable builds a new AdaptiveActivation when applied, so each layer
        # gets its own trainable parameter.
        x = Activation(self.act)(x)

        # Remaining hidden layers (indices 1 .. nl-1).
        for i, width in enumerate(hidden_dims[1:], start=1):
            x = nn.Dense(
                features=width,
                kernel_init=kernel_init,
                bias_init=bias_init,
                name=f"hidden_layer_{i}",
            )(x)
            x = Activation(self.act)(x)

        # Output layer.
        x = nn.Dense(
            features=self.out_dim,
            kernel_init=kernel_init,
            bias_init=bias_init,
            name="output_layer",
        )(x)
        x = Activation(self.out_act)(x)

        return x
