from typing import Callable, Literal

import flax.linen as nn
import jax

from egxc.utils.typing import PRECISION

from .layers import ScaledSigmoid
from .mlp import FeatureMLP


class SpatialReweighting(nn.Module):
    """
    Spatial reweighting network, a.k.a. a non-local "enhancement factor"
    applied to the base energy-density.

    The output is ``mean + delta``. ``mean`` is a learnable scalar and ``delta``
    is derived from the output of a ``FeatureMLP``:

    * ``output_activation_type == 'None'``: ``delta = scale * net(x)`` with a
      learnable scalar ``scale``.
    * ``output_activation_type == 'scaled_sigmoid'``: ``delta`` is ``net(x)``
      passed through a ``ScaledSigmoid``.

    NOTE(review): the MLP is built with ``init_last_layer_to_zero=True`` and
    ``last_layer_use_bias=False``; presumably this makes the initial output
    equal to ``init_mean`` -- confirm against ``FeatureMLP``/``ScaledSigmoid``.
    """

    # Forwarded to ``FeatureMLP`` as its first and second positional arguments.
    layers: int
    hidden_dim: int
    # Forwarded to ``FeatureMLP`` as its third positional argument.
    activation: Callable[[jax.Array], jax.Array] = nn.silu
    # Selects how the raw MLP output is turned into the correction term.
    # Note that 'None' is the string literal, not the ``None`` singleton.
    output_activation_type: Literal['scaled_sigmoid', 'None'] = 'None'
    # Initial value of the learnable ``mean`` parameter.
    init_mean: float = 1.0
    # Initial value of the learnable ``scale`` parameter ('None' mode), or the
    # ``initial_scale`` handed to ``ScaledSigmoid`` ('scaled_sigmoid' mode).
    init_scale: float = 0.1

    def setup(self):
        """Create the learnable scalars, the MLP and the optional activation.

        Raises:
            ValueError: If ``output_activation_type`` is neither
                ``'scaled_sigmoid'`` nor ``'None'``.
        """

        def scalar_param(name, value):
            # Shape-``()`` parameter, constant-initialised to ``value``.
            return self.param(
                name, jax.nn.initializers.constant(value), (), PRECISION.local_nn
            )

        self.mean = scalar_param('mean', self.init_mean)

        mlp_kernel_init = nn.initializers.variance_scaling(
            scale=2.2, mode='fan_in', distribution='uniform'
        )
        self.net = FeatureMLP(
            self.layers,
            self.hidden_dim,
            self.activation,
            concatenate=True,
            last_layer_use_bias=False,
            init_last_layer_to_zero=True,
            kernel_init=mlp_kernel_init,
        )

        mode = self.output_activation_type
        if mode == 'None':
            # Plain linear rescaling of the raw MLP output.
            self.scale = scalar_param('scale', self.init_scale)
        elif mode == 'scaled_sigmoid':
            # The y-offset is minus half of the initial scale.
            half_scale = self.init_scale / 2.0
            self.output_activation = ScaledSigmoid(
                initial_scale=self.init_scale,
                constant_y_offset=-half_scale,
            )
        else:
            raise ValueError(f'Invalid output activation: {self.output_activation_type}')

    def __call__(self, x: jax.Array) -> jax.Array:
        """Evaluate the reweighting factor for ``x``.

        Args:
            x: Input passed unchanged to the underlying ``FeatureMLP``.

        Returns:
            ``mean`` plus the (scaled or activated) MLP output, with the last
            axis squeezed away.
        """
        mlp_out = self.net(x)
        if self.output_activation_type != 'None':
            delta = self.output_activation(mlp_out)
        else:
            delta = self.scale * mlp_out
        # ``squeeze(axis=-1)`` requires the trailing axis to have length 1.
        return (self.mean + delta).squeeze(axis=-1)