from typing import Any, Callable

import flax.linen as nn
import jax.numpy as jnp
import orthax
from jax import vmap

from .layer import BoxDense


class PNets(nn.Module):
    """
    A uniform interface for polynomial basis construction that delegates
    to either ClassicalBasis (polynomial expansions) or MLPBasis (neural expansions).

    Attributes:
        basis_choice (str): Which implementation to build: 'classical' or 'mlp'.
        poly_type (str): ClassicalBasis only. Polynomial family
            ('monomial', 'legendre', or 'chebyshev').
        basis_size (int): For ClassicalBasis, the number of coefficients per
            polynomial; for MLPBasis, the output width of the final layer.
        num_partitions (int): ClassicalBasis only. Number of polynomials.
        n_hidden (int): MLPBasis only. Width of each hidden layer.
        n_layers (int): MLPBasis only. Number of residual hidden layers.
        activation (Callable): MLPBasis only. Hidden-layer activation.
        dtype: Dtype forwarded to whichever implementation is built.

    Raises:
        ValueError: When setup runs (on init/apply), if ``basis_choice`` is
            neither 'classical' nor 'mlp'.
    """
    basis_choice: str  # 'classical' or 'mlp'
    # Default Params
    ## ClassicalBasis parameters
    poly_type: str = 'chebyshev'
    basis_size: int = 4
    num_partitions: int = 2
    ## MLPBasis parameters
    n_hidden: int = 16
    n_layers: int = 2
    activation: Callable = nn.tanh
    # Shared by both implementations (field order kept for positional callers).
    dtype: Any = jnp.float64

    def setup(self):
        """Build the selected basis implementation as ``self.basis_impl``."""
        if self.basis_choice == 'classical':
            self.basis_impl = ClassicalBasis(
                poly_type=self.poly_type,
                basis_size=self.basis_size,
                num_partitions=self.num_partitions,
                # Forward dtype here too; otherwise a non-default dtype would
                # be silently ignored on the classical path.
                dtype=self.dtype,
            )
        elif self.basis_choice == 'mlp':
            self.basis_impl = MLPBasis(
                n_hidden=self.n_hidden,
                n_layers=self.n_layers,
                basis_size=self.basis_size,
                activation=self.activation,
                dtype=self.dtype,
            )
        else:
            raise ValueError(f"Unsupported basis choice: {self.basis_choice}")

    def __call__(self, x):
        """Evaluate the selected basis implementation at ``x``."""
        return self.basis_impl(x)


class MLPBasis(nn.Module):
    """
    An MLP whose outputs serve as a learned group of basis functions.

    The network applies an input ``BoxDense`` layer ("plain"), then
    ``n_layers`` ``BoxDense`` layers built with ``arch_type="resnet"`` whose
    outputs are added back to their inputs, then a ``BoxDense`` output layer
    of width ``basis_size`` with an identity activation.

    Attributes:
        n_hidden (int): Number of hidden units in each hidden layer.
        n_layers (int): Number of residual hidden layers applied after the
            input layer (``n_layers + 1`` hidden layers in total). Also passed
            to every ``BoxDense`` as ``depth``.
        basis_size (int): Number of output features (basis functions).
        activation (Callable): Activation function used in the hidden layers.
        dtype: Dtype forwarded to every ``BoxDense`` layer.
    """

    n_hidden: int
    n_layers: int
    basis_size: int
    activation: Callable
    dtype: Any = jnp.float64

    def setup(self):
        """
        Set up the layers of the network.

        This method initializes the input layer, the residual hidden layers,
        and the output layer. ``layer`` is the index of each layer in the
        stack: 0 for the input layer, 1..n_layers for the residual layers,
        and n_layers for the output layer.
        """
        self.layer_0 = BoxDense(features=self.n_hidden, activation=self.activation,
                                depth=self.n_layers, layer=0, arch_type="plain", dtype=self.dtype)
        self.layers = [
            BoxDense(features=self.n_hidden, activation=self.activation,
                     depth=self.n_layers, layer=i + 1, arch_type="resnet", dtype=self.dtype)
            for i in range(self.n_layers)
        ]
        # Identity activation: the output layer adds no nonlinearity or normalization.
        self.output = BoxDense(features=self.basis_size, activation=lambda x: x,
                               depth=self.n_layers, layer=self.n_layers,
                               arch_type="plain", dtype=self.dtype)

    def __call__(self, x):
        """
        Forward pass through the network.

        Submodules are all defined in ``setup``, so no ``@nn.compact`` is needed.

        Args:
            x (jnp.ndarray): Input array passed to the first ``BoxDense`` layer.

        Returns:
            jnp.ndarray: Output of the final layer. Presumably its last axis has
            size ``basis_size`` (assumes BoxDense maps the last axis to
            ``features`` — TODO confirm). No softmax or other normalization is
            applied.
        """
        x = self.layer_0(x)
        for layer in self.layers:
            x = x + layer(x)  # residual connection
        return self.output(x)


class ClassicalBasis(nn.Module):
    """
    A module that evaluates a group of learnable polynomials, one per partition.

    Each partition owns one row of ``basis_size`` coefficients (the
    ``basis_coeffs`` parameter, initialized to ones) in the chosen polynomial
    family. Calling the module evaluates every partition's polynomial at the
    same scalar input.

    Attributes:
        poly_type (str): Polynomial family: 'monomial', 'legendre', or 'chebyshev'.
        basis_size (int): Number of coefficients (basis terms) per polynomial.
        num_partitions (int): Number of partitions (and thus number of polynomials).
        dtype: Dtype of the ``basis_coeffs`` parameter.
    """

    poly_type: str
    basis_size: int
    num_partitions: int
    dtype: Any = jnp.float64

    def setup(self):
        """
        Create the coefficient parameter and select the polynomial evaluator.

        Raises:
            ValueError: If ``poly_type`` is not one of the supported families.
        """
        self.basis_coeffs = self.param(
            "basis_coeffs",
            nn.initializers.ones_init(),
            (self.num_partitions, self.basis_size),
            self.dtype,
        )
        # orthax evaluators all share the (x, coeffs) calling convention.
        evaluators = {
            "monomial": orthax.polynomial.polyval,
            "legendre": orthax.legendre.legval,
            "chebyshev": orthax.chebyshev.chebval,
        }
        if self.poly_type not in evaluators:
            raise ValueError(f"Unsupported polynomial type: {self.poly_type}")
        self.poly_fn = evaluators[self.poly_type]

    def __call__(self, x):
        """
        Evaluate every partition's polynomial at the given input.

        Args:
            x: A single-element input (0-d array, shape-(1,) array, or Python scalar).

        Returns:
            jnp.ndarray: Stacked evaluations, one per partition along axis 0.
            For a 0-d input the shape is ``(num_partitions,)``; for an input of
            shape ``(1,)`` it is presumably ``(num_partitions, 1)``, since orthax
            follows numpy.polynomial, where a 1-D coefficient vector yields a
            result shaped like ``x``.

        Raises:
            ValueError: If ``x`` does not contain exactly one element.
        """
        # Explicit raise rather than ``assert`` so the check survives ``python -O``.
        # ``jnp.size`` is a static Python int even under tracing.
        if jnp.size(x) != 1:
            raise ValueError("Input must be a scalar value")
        # x is shared (in_axes None); map over the partition axis of the coefficients.
        return vmap(self.poly_fn, in_axes=(None, 0))(x, self.basis_coeffs)

