# Flax
import jax.numpy as jnp
from flax import nnx
from typing import Callable, Union, Dict
from nnx_models.utils import PinnLayer, FourierLinear
# Modified MLP based on state-of-the-art practices for PINN training:
# Fourier feature embeddings and random weight factorization.
# For details, see the paper: https://arxiv.org/pdf/2210.01274


class MLP_PINN(nnx.Module):
    """MLP for physics-informed neural networks (PINNs).

    Layout: input layer -> ``num_hidden_layers`` x (PinnLayer -> act)
    -> linear output layer.

    The input layer is either a ``FourierLinear`` followed by a cos/sin
    embedding (when ``fourier_emb_scale`` is given), or a plain
    ``nnx.Linear`` followed by ``act``.
    """

    # Attribute declarations. ``act`` and ``dtype`` are assigned per instance
    # in ``__init__``. ``num_layers``, ``reparam`` and ``fourier_emb`` are
    # kept for backward compatibility only; ``__init__`` does not populate
    # them (see ``num_hidden_layers`` and ``use_fourier_emb`` instead).
    hidden_dim: int
    output_dim: int
    num_layers: int
    act: Callable = nnx.silu
    dtype: jnp.dtype = jnp.float32
    reparam: Union[None, Dict] = None
    fourier_emb: Union[None, Dict] = None

    def __init__(self, input_dim: int,
                 output_dim: int,
                 hidden_dim: int = 64,
                 num_hidden_layers: int = 2,
                 act: Callable = nnx.silu,
                 dtype: jnp.dtype = jnp.float32,
                 reparam: Union[None, Dict] = None,
                 fourier_emb_scale: Union[None, float] = None,
                 rngs: Union[None, nnx.Rngs] = None):
        """Build the network.

        Args:
            input_dim: Number of input features.
            output_dim: Number of output features.
            hidden_dim: Width of the input and hidden layers.
            num_hidden_layers: Number of (PinnLayer, act) pairs.
            act: Activation applied after every hidden PinnLayer, and after
                the input layer when no Fourier embedding is used.
            dtype: Parameter dtype forwarded to every sub-layer.
            reparam: Reparameterization config forwarded unchanged to each
                ``PinnLayer`` (e.g. ``{"type": "weight_fact", ...}``).
            fourier_emb_scale: If not ``None``, the input layer is a
                ``FourierLinear`` with this ``embed_scale`` followed by a
                cos/sin embedding; otherwise a plain ``nnx.Linear``.
            rngs: Random streams for parameter initialization. When omitted,
                a fresh ``nnx.Rngs(0)`` is created for this call.
        """
        # Create the default Rngs per call. A default instance in the
        # signature would be built once and shared (and advanced) by every
        # model constructed without an explicit ``rngs``.
        if rngs is None:
            rngs = nnx.Rngs(0)

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.num_hidden_layers = num_hidden_layers
        self.act = act
        self.dtype = dtype
        # Static flag read by __call__ to apply the matching input transform.
        self.use_fourier_emb = fourier_emb_scale is not None

        # NOTE: layers are created in the same order as before (input,
        # output, hidden) so a given ``rngs`` yields the same initialization.
        if self.use_fourier_emb:
            self.in_layer = FourierLinear(
                input_dim=input_dim,
                output_dim=hidden_dim,
                embed_scale=fourier_emb_scale,
                dtype=dtype,
                rngs=rngs
            )
        else:
            self.in_layer = nnx.Linear(
                in_features=input_dim,
                out_features=hidden_dim,
                use_bias=True,
                param_dtype=dtype,
                rngs=rngs
            )

        self.out_layer = nnx.Linear(
            in_features=hidden_dim,
            out_features=output_dim,
            use_bias=True,
            param_dtype=dtype,
            rngs=rngs
        )

        # Collect (PinnLayer, act) pairs locally, then wrap them once.
        layers = []
        for _ in range(num_hidden_layers):
            layers.append(
                PinnLayer(
                    input_dim=hidden_dim,
                    output_dim=hidden_dim,
                    kernel_init=nnx.initializers.glorot_normal(),
                    bias_init=nnx.initializers.zeros,
                    reparam=reparam,
                    dtype=dtype,
                    rngs=rngs
                )
            )
            layers.append(self.act)
        self.hidden_layers = nnx.Sequential(*layers)

    def __call__(self, x):
        """Run a forward pass.

        Args:
            x: Input array with ``input_dim`` features on its last axis.

        Returns:
            Array with ``output_dim`` features on its last axis.
        """
        x = self.in_layer(x)
        if self.use_fourier_emb:
            # Sinusoidal Fourier-feature embedding. Concatenating cos and sin
            # doubles the last axis; NOTE(review): this presumes FourierLinear
            # emits hidden_dim // 2 features so the result is hidden_dim wide
            # — confirm against nnx_models.utils.
            x = jnp.concatenate([jnp.cos(x), jnp.sin(x)], axis=-1)
        else:
            # Plain input layer: use the ordinary MLP activation. A cos/sin
            # concatenation here would produce 2 * hidden_dim features, which
            # does not match the input_dim=hidden_dim the PinnLayers use.
            x = self.act(x)
        x = self.hidden_layers(x)
        return self.out_layer(x)

if __name__ == "__main__":
    # Smoke test: Fourier-embedded input layer plus weight-factorized
    # hidden layers, evaluated on a dummy batch.
    weight_fact_cfg = dict(type="weight_fact", mean=0.0, stddev=0.1)
    net = MLP_PINN(
        input_dim=2,
        output_dim=3,
        hidden_dim=64,
        num_hidden_layers=3,
        reparam=weight_fact_cfg,
        fourier_emb_scale=2.0,
        rngs=nnx.Rngs(0),
    )
    batch = jnp.ones((10, 2))  # 10 points with 2 input features each
    out = net(batch)
    # Dump every trainable parameter of the model.
    print(nnx.state(net, nnx.Param))
    print(out.shape)  # expected: (10, 3)
