from collections.abc import Callable
from typing import Final

from flax import nnx
from flax.typing import Dtype
from jax import Array
from jax.nn import relu, softplus
from jax.numpy import float32, tanh


# Hyperparameter names associated with ``MLP``, kept in alphabetical order.
# NOTE(review): the dropout entry is spelled "dropout_rate", whereas
# ``MLP.__init__`` names its dropout parameter ``dropout``; confirm how
# consumers of this tuple map that key before forwarding it.
MLP_KWARGS: Final[tuple[str, ...]] = tuple(
    "dropout_rate hidden_features layer_norm nonlinearity num_layers".split()
)


def mish(x: Array) -> Array:
    """Mish activation, computed elementwise as ``tanh(softplus(x)) * x``."""
    smoothed = softplus(x)
    gate = tanh(smoothed)
    return gate * x


# Registry of activation functions that can be selected by name, e.g. via
# the string form of ``MLP``'s ``nonlinearity`` argument.
NONLINEARITIES: Final[dict[str, Callable[[Array], Array]]] = dict(
    mish=mish,
    relu=relu,
    tanh=tanh,
)


class Identity(nnx.Module):
    """Parameter-free module that returns its input unchanged.

    ``MLP`` uses it as a stand-in for LayerNorm when normalisation is off.
    """

    def __call__(self, x):
        # Pure pass-through: no computation and no state.
        result = x
        return result


class MLP(nnx.Module):
    """Fully connected network built from ``num_layers - 1`` Linear layers.

    Every Linear layer except the last is followed by (optional) LayerNorm,
    Dropout and the nonlinearity; the last Linear layer is returned as is.
    ``num_layers`` counts the input and output layers, so ``num_layers=2`` is
    a single Linear map ``in_features -> out_features`` and each additional
    layer adds one hidden Linear of width ``hidden_features``.

    Args:
        in_features: Input feature size of the first Linear layer.
        out_features: Output feature size of the last Linear layer.
        rngs: Random streams used to build the Linear/LayerNorm layers and
            the Dropout module.
        dropout: Rate passed to ``nnx.Dropout`` (applied after each hidden
            layer).
        hidden_features: Width of every hidden layer.
        layer_norm: If True, apply ``nnx.LayerNorm`` after each hidden Linear.
        nonlinearity: Activation callable, or a key of ``NONLINEARITIES``.
        num_layers: Layer count including input and output; must be >= 2.
        param_dtype: dtype of the Linear/LayerNorm parameters.
        dropout_rate: Keyword alias for ``dropout``, matching the
            ``"dropout_rate"`` name listed in ``MLP_KWARGS``. It may not be
            combined with a different non-zero ``dropout``.

    Raises:
        ValueError: If ``num_layers < 2`` or ``dropout`` and ``dropout_rate``
            conflict.
        KeyError: If ``nonlinearity`` is a string not in ``NONLINEARITIES``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        rngs: nnx.Rngs,
        dropout: float = 0,
        hidden_features: int = 256,
        layer_norm: bool = False,
        nonlinearity: Callable[[Array], Array] | str = "relu",
        num_layers: int = 4,
        param_dtype: Dtype = float32,
        dropout_rate: float | None = None,
    ):
        if num_layers < 2:
            raise ValueError(f"num_layers should be at least 2: {num_layers}")
        # Accept the "dropout_rate" spelling used by MLP_KWARGS so configs
        # filtered through that tuple can be forwarded without a TypeError.
        if dropout_rate is not None:
            if dropout and dropout != dropout_rate:
                raise ValueError(
                    f"pass only one of dropout ({dropout}) and "
                    f"dropout_rate ({dropout_rate})"
                )
            dropout = dropout_rate
        # Resolve the activation before building layers so a bad name fails
        # fast, without consuming any rng streams.
        if isinstance(nonlinearity, str):
            try:
                nonlinearity = NONLINEARITIES[nonlinearity]
            except KeyError:
                raise KeyError(
                    f"unknown nonlinearity {nonlinearity!r}; expected a "
                    f"callable or one of {sorted(NONLINEARITIES)}"
                ) from None

        num_hidden = num_layers - 2
        in_list = [in_features] + [hidden_features] * num_hidden
        out_list = [hidden_features] * num_hidden + [out_features]
        # Construction order (Dropout, Linears, LayerNorms) is kept fixed
        # because it determines how the shared rng streams are consumed.
        self.dropout = nnx.Dropout(dropout, rngs=rngs)
        self.linears = [
            nnx.Linear(f_in, f_out, param_dtype=param_dtype, rngs=rngs)
            for f_in, f_out in zip(in_list, out_list)
        ]
        # One entry per hidden layer, i.e. one fewer than len(self.linears).
        self.layer_norms: list[nnx.LayerNorm | Identity]
        if layer_norm:
            self.layer_norms = [
                nnx.LayerNorm(f_out, param_dtype=param_dtype, rngs=rngs)
                for f_out in out_list[:-1]
            ]
        else:
            # Identity is stateless, so a single shared instance suffices.
            self.layer_norms = [Identity()] * num_hidden
        self.nonlinearity: Callable[[Array], Array] = nonlinearity

    def __call__(self, x) -> Array:
        """Run ``x`` through the network and return the final Linear output."""
        # zip stops after the hidden layers because layer_norms is one entry
        # shorter than linears; the output layer is applied separately below.
        for linear, layer_norm in zip(self.linears, self.layer_norms):
            x = linear(x)
            x = layer_norm(x)
            x = self.dropout(x)
            x = self.nonlinearity(x)
        return self.linears[-1](x)


class MLPEnsemble(nnx.Module):
    """Ensemble of ``ensemble_size`` MLPs built and evaluated with ``nnx.vmap``.

    Member construction runs under ``nnx.split_rngs`` + ``nnx.vmap``, so
    each member receives its own split of ``rngs``. A call maps
    ``MLP.__call__`` over the members while feeding every member the same
    input, and stacks the member outputs along ``out_axis``.

    Args:
        ensemble_size: Number of ensemble members.
        in_features: Input feature size of each member MLP.
        out_features: Output feature size of each member MLP.
        rngs: Random streams, split ``ensemble_size`` ways for construction.
        out_axis: Axis of the output along which member results are stacked.
        **kwargs: Forwarded unchanged to every ``MLP``.
    """

    def __init__(
        self,
        ensemble_size: int,
        in_features: int,
        out_features: int,
        rngs: nnx.Rngs,
        out_axis: int = 0,
        **kwargs,
    ):
        # Split the rng streams per member, then vmap construction so the
        # members are created in one batched call.
        @nnx.split_rngs(splits=ensemble_size)
        @nnx.vmap
        def build_members(member_rngs: nnx.Rngs) -> MLP:
            return MLP(
                in_features=in_features,
                out_features=out_features,
                rngs=member_rngs,
                **kwargs,
            )

        self.model = build_members(rngs)

        # in_axes=(0, None): map over the stacked members, broadcast the
        # input. split_rngs again splits any rng streams reachable from the
        # arguments (presumably the members' dropout streams) on each call.
        batched_call = nnx.vmap(in_axes=(0, None), out_axes=out_axis)(
            MLP.__call__
        )
        self.forward = nnx.split_rngs(splits=ensemble_size)(batched_call)

    def __call__(self, x) -> Array:
        """Evaluate every member on ``x`` and return the stacked outputs."""
        return self.forward(self.model, x)
