from typing import Any, Callable

import flax.linen as nn
import jax.numpy as jnp

from .layer import BoxDense

class GatingNetwork(nn.Module):
    """
    Feed-forward gating network that maps an input to partition weights.

    The input passes through one "plain" ``BoxDense`` layer, then ``n_layers``
    "resnet" ``BoxDense`` layers whose outputs are added back onto their
    inputs, and finally a "plain" ``BoxDense`` output layer with an identity
    activation. The result is softmax-normalized over the last axis, so the
    returned weights are non-negative and sum to 1 along that axis.

    Attributes:
        n_hidden (int): Value passed as ``features`` to every hidden layer.
        n_layers (int): Number of residual hidden layers stacked after the
            first layer; also passed to every ``BoxDense`` as its depth.
        num_partitions (int): Size requested from the output layer, i.e. the
            number of partitions the weights are distributed over.
        activation (Callable): Activation handed to the hidden layers.
        embedding (Callable): Required constructor field that is not read
            anywhere in this class; kept so existing callers keep working.
        dtype (Any): dtype forwarded to every ``BoxDense``. Defaults to
            ``jnp.float64``; note that JAX only uses 64-bit floats when the
            ``jax_enable_x64`` option is turned on.
    """

    n_hidden: int
    n_layers: int
    num_partitions: int
    activation: Callable
    embedding: Callable
    dtype: Any = jnp.float64

    def setup(self):
        """
        Create the input layer, the residual hidden stack and the output layer.

        The attribute names (``layer_0``, ``layers``, ``output``) determine the
        submodule names in the parameter tree, so they must stay stable.
        """
        self.layer_0 = BoxDense(
            features=self.n_hidden,
            activation=self.activation,
            depth=self.n_layers,
            layer=0,
            arch_type="plain",
            dtype=self.dtype,
        )

        # Hidden layers are numbered 1..n_layers; layer_0 above holds index 0.
        self.layers = [
            BoxDense(self.n_hidden, self.activation, self.n_layers, index, "resnet", dtype=self.dtype)
            for index in range(1, self.n_layers + 1)
        ]

        # Identity activation: the softmax is applied in __call__ instead.
        # NOTE(review): the layer index passed here (n_layers) equals the index
        # of the last hidden layer -- confirm against BoxDense that this reuse
        # is intended.
        self.output = BoxDense(self.num_partitions, lambda x: x, self.n_layers, self.n_layers, "plain", dtype=self.dtype)

    # No @nn.compact here: every submodule is declared in setup(), and nothing
    # is defined inline in this method.
    def __call__(self, x):
        """
        Forward pass through the network.

        Args:
            x (jnp.ndarray): Input tensor.

        Returns:
            jnp.ndarray: Softmax-normalized output representing partition probabilities.
        """
        hidden = self.layer_0(x)
        for block in self.layers:
            # NOTE(review): a skip connection is added here although the layers
            # are built with arch_type="resnet"; if BoxDense already adds its
            # own skip in that mode this doubles it -- confirm.
            hidden = hidden + block(hidden)
        logits = self.output(hidden)
        return nn.softmax(logits, axis=-1)

