import jax
import jax.numpy as jnp
from jax import random
from jax.nn import initializers

import inspect
import flax
from flax import linen as nn
from flax.training import train_state, checkpoints
from typing import Callable, Tuple

"""An MLP model that consists of multiple dense layers followed by an output layer."""


class MLP(nn.Module):
    """Multi-Layer Perceptron (MLP) neural network model.

    This class represents an MLP model that consists of multiple dense layers, each followed by
    the activation function, and a final dense output layer to which no activation is applied.

    Args:
        act_fn: Activation function applied after every hidden dense layer. Pass None to skip
            the activation. Default is nn.relu.
        hidden_sizes: Tuple of integers representing the sizes of the hidden layers.
            Default is (64,).
        output_size: Integer representing the size of the output layer. Default is 14.
        use_bias: Whether to include bias terms in the hidden dense layers. Default is True.
        use_bias_output: Whether to include a bias term in the output layer. Default is False.
        init_scale: Value passed as ``stddev`` to any initializer factory whose signature has a
            ``stddev`` parameter. Default is 0.0002.
        kernel_init: Initializer *factory* for the kernel weights. It is called with no
            positional arguments (plus ``stddev`` when supported) and its result is handed to
            nn.Dense as ``kernel_init``. Default is initializers.normal.
        bias_init: Initializer *factory* for the bias terms, treated the same way as
            ``kernel_init``. Default is initializers.normal.

    Methods:
        custom_initialisation: Calls an initializer factory, injecting ``init_scale`` as
            ``stddev`` when the factory supports it. Used for both weights and biases.
        setup: Setup function creating the hidden layers and the output layer.
        __call__: Forward pass function of the MLP.

    Returns:
        The output of the MLP.
    """
    act_fn: Callable = nn.relu
    hidden_sizes: Tuple[int, ...] = (64,)
    output_size: int = 14
    use_bias: bool = True
    use_bias_output: bool = False
    init_scale: float = 0.0002
    kernel_init: Callable = initializers.normal
    bias_init: Callable = initializers.normal

    def setup(self):
        """Create one dense layer per entry of ``hidden_sizes`` plus the output layer."""
        self.dense_layers = [
            nn.Dense(
                size,
                use_bias=self.use_bias,
                kernel_init=self.custom_initialisation(self.kernel_init),
                bias_init=self.custom_initialisation(self.bias_init),
            )
            for size in self.hidden_sizes
        ]

        # The output layer has its own bias switch (use_bias_output), independent of use_bias.
        self.output_layer = nn.Dense(
            self.output_size,
            use_bias=self.use_bias_output,
            kernel_init=self.custom_initialisation(self.kernel_init),
            bias_init=self.custom_initialisation(self.bias_init),
        )

    def __call__(self, x):
        """Run the forward pass.

        Args:
            x: Input fed to the first dense layer.

        Returns:
            The output of the final dense layer; no activation is applied to it.
        """
        for dense in self.dense_layers:
            x = dense(x)
            if self.act_fn is not None:
                x = self.act_fn(x)  # pylint: disable=too-many-function-args
        return self.output_layer(x)

    def custom_initialisation(self, init_fn, *args, **kwargs):
        """Call an initializer factory, supplying ``init_scale`` as ``stddev`` when supported.

        Can be used for both weights and biases.

        Args:
            init_fn: The initialization function (factory) to be called.
            *args: Positional arguments to be passed to the initialization function.
            **kwargs: Keyword arguments to be passed to the initialization function.
                An explicitly supplied ``stddev`` (positional or keyword) takes precedence
                over ``self.init_scale``.

        Returns:
            The result of the initialization function.
        """
        signature = inspect.signature(init_fn)
        # Check if 'stddev' is a valid parameter for the initializer
        if 'stddev' in signature.parameters:
            # Only inject the default scale when the caller has not already bound 'stddev';
            # passing it twice would raise "got multiple values for keyword argument".
            already_bound = signature.bind_partial(*args, **kwargs).arguments
            if 'stddev' not in already_bound:
                kwargs['stddev'] = self.init_scale
        return init_fn(*args, **kwargs)
    
if __name__ == "__main__":
    # Smoke test: initialise a default MLP on a dummy batch of ones, then
    # print the module, the result of one forward pass, and the parameters.
    rng = random.PRNGKey(0)
    dummy_batch = jnp.ones((1, 10))
    mlp = MLP()
    variables = mlp.init(rng, dummy_batch)
    print(mlp)
    print(mlp.apply(variables, dummy_batch))
    print(variables)