from typing import Callable, Optional, Sequence

import flax.linen as nn
import jax
import jax.numpy as jnp


def default_init(scale: Optional[float] = jnp.sqrt(2)):  # noqa
    """Return an orthogonal kernel initializer.

    Args:
        scale: Multiplier applied to the sampled orthogonal matrix. Defaults
            to ``sqrt(2)``. ``None`` is accepted (as the ``Optional`` annotation
            advertises) and falls back to that same ``sqrt(2)`` default.

    Returns:
        A JAX initializer callable ``(key, shape, dtype=...) -> Array``.
        ``flax.linen.initializers.orthogonal`` is a re-export of the same
        ``jax.nn.initializers.orthogonal`` function used here.
    """
    if scale is None:
        # orthogonal(None) is constructed without error but fails later, when
        # the initializer multiplies the matrix by None; use the default instead.
        scale = jnp.sqrt(2)
    return jax.nn.initializers.orthogonal(scale)


def largescale_init(scale: Optional[float] = None):  # noqa
    """Return an initializer that fills every entry with one constant value.

    Args:
        scale: The fill value. ``None`` (the default) keeps the historical
            value ``10e6``. Previously this argument was accepted but ignored.

    Returns:
        A JAX initializer callable ``(key, shape, dtype=...) -> Array``.
        ``flax.linen.initializers.constant`` is a re-export of the same
        ``jax.nn.initializers.constant`` function used here.
    """
    # NOTE(review): 10e6 == 1e7, not 1e6 -- confirm the intended magnitude.
    fill_value = 10e6 if scale is None else scale
    return jax.nn.initializers.constant(fill_value)


class MLP(nn.Module):
    """Feed-forward stack of Dense layers.

    Each entry of ``hidden_dims`` produces one ``nn.Dense`` layer whose kernel
    comes from ``default_init()``. Every layer except the last is followed by,
    in this order: dropout (when ``dropout_rate`` is not ``None``), LayerNorm
    (when ``layernorm`` is truthy), and ``activations``. The last layer gets
    the same treatment only when ``activate_final`` is truthy.
    """

    # Output width of each Dense layer, in application order.
    hidden_dims: Sequence[int]
    # Nonlinearity applied after the optional dropout / LayerNorm.
    activations: Callable = nn.relu
    # Whether the last Dense layer also gets dropout/LayerNorm/activation.
    activate_final: Optional[bool] = False
    # Insert nn.LayerNorm before each activation when truthy.
    layernorm: Optional[bool] = False
    # Dropout rate; None disables dropout entirely (no Dropout module created).
    dropout_rate: Optional[float] = None

    @nn.compact
    def __call__(self, x: jax.Array, training: bool = False) -> jax.Array:
        """Run ``x`` through the stack.

        Args:
            x: Input array fed to the first Dense layer.
            training: When True, dropout is applied stochastically; otherwise
                it runs deterministically. Presumably the caller must supply a
                'dropout' RNG when training -- see flax ``nn.Dropout`` docs.

        Returns:
            The output of the final layer, activated only if
            ``activate_final`` is truthy.
        """
        last_index = len(self.hidden_dims) - 1

        for layer_idx, width in enumerate(self.hidden_dims):
            dense = nn.Dense(width, kernel_init=default_init())
            x = dense(x)

            # The last layer stays linear unless activate_final is requested.
            if layer_idx == last_index and not self.activate_final:
                continue

            if self.dropout_rate is not None:
                dropout = nn.Dropout(rate=self.dropout_rate)
                x = dropout(x, deterministic=not training)

            if self.layernorm:
                x = nn.LayerNorm()(x)

            x = self.activations(x)

        return x
