from collections.abc import Sequence

import keras
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Tensor
from bayesflow.utils import filter_kwargs
from bayesflow.utils.decorators import sanitize_input_shape

from .equivariant_module import EquivariantModule
from .invariant_module import InvariantModule
from ..summary_network import SummaryNetwork


@serializable(package="bayesflow.networks")
class DeepSet(SummaryNetwork):
    """Implements a deep set encoder introduced in [1] for learning permutation-invariant representations of
    set-based data, as generated by exchangeable models.

    [1] Zaheer, M., Kottur, S., Ravanbakhsh, S., Poczos, B., Salakhutdinov, R. R., & Smola, A. J. (2017).
    Deep sets. Advances in neural information processing systems, 30.
    """

    def __init__(
        self,
        summary_dim: int = 16,
        depth: int = 2,
        inner_pooling: str = "mean",
        output_pooling: str = "mean",
        mlp_widths_equivariant: Sequence[int] = (64, 64),
        mlp_widths_invariant_inner: Sequence[int] = (64, 64),
        mlp_widths_invariant_outer: Sequence[int] = (64, 64),
        mlp_widths_invariant_last: Sequence[int] = (64, 64),
        activation: str = "gelu",
        kernel_initializer: str = "he_normal",
        dropout: int | float | None = 0.05,
        spectral_normalization: bool = False,
        **kwargs,
    ):
        """
        Initializes a fully customizable deep learning model for learning permutation-invariant representations of
        sets (i.e., exchangeable or IID data). Do not use this model for non-IID data (e.g., time series).

        Important: Prefer a SetTransformer to a DeepSet, especially if the simulation budget is high.

        The model consists of multiple stacked equivariant transformation modules followed by an invariant pooling
        module to produce a compact set representation.

        The equivariant layers perform many-to-many transformations, preserving structural information, while
        the final invariant module aggregates the set into a lower-dimensional summary.

        The model supports various activation functions, kernel initializations, and optional spectral normalization
        for stability. Pooling mechanisms can be specified for both intermediate and final aggregation steps.

        Parameters
        ----------
        summary_dim : int, optional
            Dimensionality of the final learned summary statistics. Default is 16.
        depth : int, optional
            Number of stacked equivariant modules. Default is 2.
        inner_pooling : str, optional
            Type of pooling operation applied within equivariant modules, such as "mean".
            Default is "mean".
        output_pooling : str, optional
            Type of pooling operation applied in the final invariant module, such as "mean".
            Default is "mean".
        mlp_widths_equivariant : Sequence[int], optional
            Widths of the MLP layers inside the equivariant modules. Default is (64, 64).
        mlp_widths_invariant_inner : Sequence[int], optional
            Widths of the inner MLP layers of the invariant sub-module inside each equivariant module.
            Default is (64, 64).
        mlp_widths_invariant_outer : Sequence[int], optional
            Widths of the outer MLP layers of the invariant sub-module inside each equivariant module.
            Default is (64, 64).
        mlp_widths_invariant_last : Sequence[int], optional
            Widths of both the inner and the outer MLP layers of the final invariant module.
            Default is (64, 64).
        activation : str, optional
            Activation function used throughout the network, such as "gelu". Default is "gelu".
        kernel_initializer : str, optional
            Initialization strategy for kernel weights, such as "he_normal". Default is "he_normal".
        dropout : int, float, or None, optional
            Dropout rate applied within MLP layers. Default is 0.05.
        spectral_normalization : bool, optional
            Whether to apply spectral normalization to stabilize training. Default is False.
        **kwargs
            Additional keyword arguments. They are forwarded to the ``SummaryNetwork`` base-class
            constructor and, after filtering with ``filter_kwargs``, to every equivariant module
            and to the final invariant module.
        """

        super().__init__(**kwargs)

        # Stack of ``depth`` equivariant modules for a many-to-many learnable transformation.
        # Every module receives the same hyperparameters, but each is a separate layer instance
        # with its own weights.
        self.equivariant_modules = [
            EquivariantModule(
                mlp_widths_equivariant=mlp_widths_equivariant,
                mlp_widths_invariant_inner=mlp_widths_invariant_inner,
                mlp_widths_invariant_outer=mlp_widths_invariant_outer,
                activation=activation,
                kernel_initializer=kernel_initializer,
                spectral_normalization=spectral_normalization,
                dropout=dropout,
                pooling=inner_pooling,
                **filter_kwargs(kwargs, EquivariantModule),
            )
            for _ in range(depth)
        ]

        # Invariant module for a many-to-one transformation. The same width spec is
        # deliberately used for both its inner and its outer MLP.
        self.invariant_module = InvariantModule(
            mlp_widths_inner=mlp_widths_invariant_last,
            mlp_widths_outer=mlp_widths_invariant_last,
            activation=activation,
            kernel_initializer=kernel_initializer,
            dropout=dropout,
            pooling=output_pooling,
            spectral_normalization=spectral_normalization,
            **filter_kwargs(kwargs, InvariantModule),
        )

        # Output linear layer to project set representation down to "summary_dim" learned summary statistics
        self.output_projector = keras.layers.Dense(summary_dim, activation="linear")
        self.summary_dim = summary_dim

    @sanitize_input_shape
    def build(self, input_shape):
        """Build the network by running one forward pass on a zero tensor of ``input_shape``.

        Running ``call`` once makes every sub-layer (equivariant modules, invariant module,
        output projector) build itself for the incoming shape.
        """
        super().build(input_shape)
        self.call(keras.ops.zeros(input_shape))

    def call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor:
        """
        Performs the forward pass of a hierarchical deep invariant transformation.

        This function applies a sequence of equivariant transformations to the input tensor,
        preserving structural relationships while refining representations. After passing
        through the equivariant modules, the data is processed by an invariant transformation,
        which aggregates information into a lower-dimensional representation. The final output
        is projected to the specified summary dimension using a linear layer.

        Parameters
        ----------
        x : Tensor
            Input tensor representing a set or collection of elements to be transformed
            (presumably of shape ``(batch, set_size, features)`` -- confirm against the modules).
        training : bool, optional
            Whether the model is in training mode, affecting layers like dropout. Default is False.
        **kwargs
            Accepted for interface compatibility with other summary networks. They are currently
            ignored and NOT forwarded to the sub-layers.

        Returns
        -------
        output : Tensor
            Learned summary of the input set, produced by a ``Dense(summary_dim)`` projection,
            so its last axis has size ``summary_dim``.
        """

        for em in self.equivariant_modules:
            x = em(x, training=training)

        x = self.invariant_module(x, training=training)

        return self.output_projector(x)
