from typing import Callable, Sequence

import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.linen.initializers import Initializer

from egxc.utils.typing import (
    PRECISION,
    FloatN,
)

from .layers import LinearSkip


class MLP(nn.Module):
    """Multi-layer perceptron assembled from dense layers.

    Each entry of ``dims`` except the last creates one hidden layer that is
    followed by ``activation``. The last entry is the width of a closing
    ``nn.Dense`` layer whose output is returned without an activation.

    Attributes:
        dims: Layer widths; ``dims[:-1]`` are hidden widths, ``dims[-1]`` is
            the output width.
        activation: Function applied after every hidden layer.
        dtype: Passed as ``dtype`` to every ``nn.Dense`` layer.
        add_skip_connection: Build hidden layers from ``LinearSkip`` instead
            of ``nn.Dense``.
        use_bias: ``use_bias`` setting of the hidden layers.
        init_last_layer_to_zero: Initialise the output kernel with zeros
            instead of ``lecun_normal``.
        last_layer_use_bias: ``use_bias`` setting of the output layer.
        kernel_init: Kernel initializer of the hidden layers only; the output
            layer never uses it.
    """

    dims: Sequence[int]
    activation: Callable[[jax.Array], jax.Array]
    dtype: jnp.dtype
    add_skip_connection: bool = False
    use_bias: bool = True
    init_last_layer_to_zero: bool = False
    last_layer_use_bias: bool = True
    kernel_init: Initializer = nn.initializers.lecun_normal()

    @nn.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        """Apply all hidden layers with activation, then the output layer."""
        for width in self.dims[:-1]:
            if self.add_skip_connection:
                # NOTE(review): `dtype` is not forwarded to LinearSkip, unlike
                # the nn.Dense branch -- confirm whether that is intended.
                hidden_layer = LinearSkip(
                    width, kernel_init=self.kernel_init, use_bias=self.use_bias
                )
            else:
                hidden_layer = nn.Dense(
                    width,
                    dtype=self.dtype,
                    kernel_init=self.kernel_init,
                    use_bias=self.use_bias,
                )
            x = self.activation(hidden_layer(x))

        if self.init_last_layer_to_zero:
            output_kernel_init = nn.initializers.zeros_init()
        else:
            # output should have variance 1
            output_kernel_init = nn.initializers.lecun_normal()

        output_layer = nn.Dense(
            self.dims[-1],
            kernel_init=output_kernel_init,
            dtype=self.dtype,
            use_bias=self.last_layer_use_bias,
        )
        return output_layer(x)


class DoubleMLPWithCrossConnections(nn.Module):
    """Two dense stacks processed side by side with optional cross-links.

    Both branches (``x`` and ``y``) use the same widths from ``dims``. At
    hidden layer ``i`` (counted from 0):

    * if ``-i`` is in ``cross_connections``, ``y`` is concatenated onto ``x``
      before the ``x`` dense layer;
    * if ``i`` is in ``cross_connections``, the already updated ``x`` is
      concatenated with ``y`` before the ``y`` dense layer.

    Because ``-0 == 0``, listing ``0`` enables both links at the first layer.

    Attributes:
        dims: Layer widths shared by both branches; ``dims[-1]`` is the
            output width.
        cross_connections: Signed layer indices selecting the links above.
        activation: Function applied after every hidden layer, and after the
            output layer when ``apply_activation_to_output`` is set.
        init_last_layer_to_zero: Initialise both output kernels with zeros
            instead of ``lecun_normal``.
        apply_activation_to_output: Also apply ``activation`` to both outputs.
        last_layer_use_bias: ``use_bias`` setting of both output layers.
    """

    dims: Sequence[int]
    cross_connections: Sequence[int]
    activation: Callable[[jax.Array], jax.Array]
    init_last_layer_to_zero: bool = False
    apply_activation_to_output: bool = False
    last_layer_use_bias: bool = True

    @nn.compact
    def __call__(self, x: jax.Array, y: jax.Array) -> tuple[jax.Array, jax.Array]:
        """Run both branches and return the pair ``(x, y)``."""
        for layer_idx, width in enumerate(self.dims[:-1]):
            # x branch: optionally sees the current y before its dense layer.
            if -layer_idx in self.cross_connections:
                x = jnp.concatenate([x, y], axis=-1)
            x = self.activation(nn.Dense(width)(x))

            # y branch: optionally sees the x that was just updated above.
            if layer_idx in self.cross_connections:
                y = jnp.concatenate([x, y], axis=-1)
            y = self.activation(nn.Dense(width)(y))

        if self.init_last_layer_to_zero:
            output_kernel_init = nn.initializers.zeros_init()
        else:
            output_kernel_init = nn.initializers.lecun_normal()

        # Both output layers share the same configuration.
        output_kwargs = {
            "kernel_init": output_kernel_init,
            "use_bias": self.last_layer_use_bias,
        }
        x = nn.Dense(self.dims[-1], **output_kwargs)(x)
        y = nn.Dense(self.dims[-1], **output_kwargs)(y)

        if not self.apply_activation_to_output:
            return x, y
        return self.activation(x), self.activation(y)


class FeatureMLP(nn.Module):
    """
    A simple multi-layer perceptron to act on (semi)local electron density features
    and other NN applied on the quadrature grid.

    Calling the module runs the whole pipeline: input layer → activation → MLP.
    To intervene between the stages (e.g., to add non-local features to the hidden
    representation), obtain the hidden representation from `_local_input` and apply
    `activation` and `mlp` yourself:

        x = net._local_input(features)
        x = x + non_local_features  # inject non-local features
        x = activation(x)
        x = net.mlp(x)

    The inner MLP is given `n_layers - 2` hidden widths of `hidden_dim` followed by
    `output_dim`. For `n_layers < 2` the hidden-width list is empty, so the network
    is the same as for `n_layers == 2`.
    """

    n_layers: int
    hidden_dim: int
    activation: Callable[[jax.Array], jax.Array]
    init_last_layer_to_zero: bool
    add_skip_connection: bool = False
    output_dim: int = 1
    last_layer_use_bias: bool = True
    concatenate: bool = False
    kernel_init: Initializer = nn.initializers.lecun_normal()

    def setup(self):
        """Create the input dense layer and the MLP that follows it."""
        local_dtype = PRECISION.local_nn

        self.local_input_layer = nn.Dense(
            self.hidden_dim,
            dtype=local_dtype,
            kernel_init=self.kernel_init,
        )

        n_hidden = self.n_layers - 2
        mlp_dims = [self.hidden_dim] * n_hidden + [self.output_dim]
        self.mlp = MLP(
            dims=mlp_dims,
            activation=self.activation,
            dtype=local_dtype,
            add_skip_connection=self.add_skip_connection,
            init_last_layer_to_zero=self.init_last_layer_to_zero,
            last_layer_use_bias=self.last_layer_use_bias,
            kernel_init=self.kernel_init,
        )

    def __call__(self, *feats: FloatN) -> FloatN:
        """Run input layer, activation and MLP on the given features."""
        hidden = self._local_input(*feats)
        return self.mlp(self.activation(hidden))

    def _local_input(self, *feats: FloatN) -> FloatN:
        """Combine the features along the last axis and apply the input layer.

        With `concatenate` set the features are joined along their existing
        last axis; otherwise they are stacked along a new last axis.
        """
        combine = jnp.concatenate if self.concatenate else jnp.stack
        return self.local_input_layer(combine(feats, axis=-1))
