# DeepSet AttentionPooling

from collections.abc import Sequence

import keras

from bayesflow.types import Tensor
from bayesflow.utils.serialization import serializable

from bayesflow.networks.deep_set.equivariant_layer import EquivariantLayer

from bayesflow.networks.transformers.pma import PoolingByMultiHeadAttention

from bayesflow.networks import SummaryNetwork


@serializable("bayesflow.networks")
class DeepSetMHA(SummaryNetwork):
    """(SN) Deep set encoder [1] with attention-based pooling for learning permutation-invariant
    representations of set-based data, as generated by exchangeable models.

    A stack of equivariant layers first transforms each set element. The resulting set is then
    aggregated by a ``PoolingByMultiHeadAttention`` layer (attention pooling in the style of the
    Set Transformer [2]) instead of a fixed invariant pooling. A final linear layer projects the
    result to ``summary_dim`` summary statistics.

    [1] Zaheer, M., Kottur, S., Ravanbakhsh, S., Poczos, B., Salakhutdinov, R. R., & Smola, A. J. (2017).
    Deep sets. Advances in neural information processing systems, 30.

    [2] Lee, J., Lee, Y., Kim, J., Kosiorek, A., Choi, S., & Teh, Y. W. (2019).
    Set transformer: A framework for attention-based permutation-invariant neural networks.
    International Conference on Machine Learning.
    """

    def __init__(
        self,
        summary_dim: int = 16,
        depth: int = 2,
        inner_pooling: str = "mean",
        mlp_widths_equivariant: Sequence[int] = (64, 64),
        mlp_widths_invariant_inner: Sequence[int] = (64, 64),
        mlp_widths_invariant_outer: Sequence[int] = (64, 64),
        activation: str = "silu",
        kernel_initializer: str = "he_normal",
        dropout: int | float | None = 0.05,
        spectral_normalization: bool = False,
        num_heads: int = 4,
        mlp_depth: int = 2,
        mlp_width: int = 128,
        num_seeds: int = 1,
        **kwargs,
    ):
        """
        Initializes a deep set model with attention pooling for learning permutation-invariant
        representations of sets (i.e., exchangeable or IID data). Do not use this model for non-IID
        data (e.g., time series).

        Important: Prefer a SetTransformer to a DeepSet, especially if the simulation budget is high.

        The model consists of ``depth`` stacked equivariant modules, followed by a multi-head
        attention pooling layer and a linear output projection. The equivariant layers perform
        many-to-many transformations, preserving structural information. The attention pooling
        aggregates the set into a fixed-size representation.

        Parameters
        ----------
        summary_dim : int, optional
            Dimensionality of the final learned summary statistics. Default is 16.
        depth : int, optional
            Number of stacked equivariant modules. Default is 2.
        inner_pooling : str, optional
            Type of pooling operation applied within the equivariant modules, such as "mean".
            Default is "mean".
        mlp_widths_equivariant : Sequence[int], optional
            Widths of the MLP layers inside the equivariant modules. Must be non-empty; its last
            entry is also used as ``embed_dim`` and ``seed_dim`` of the attention pooling.
            Default is (64, 64).
        mlp_widths_invariant_inner : Sequence[int], optional
            Widths of the inner invariant MLP layers within each equivariant module.
            Default is (64, 64).
        mlp_widths_invariant_outer : Sequence[int], optional
            Widths of the outer invariant MLP layers within each equivariant module.
            Default is (64, 64).
        activation : str, optional
            Activation function used in the equivariant modules, such as "gelu". The attention
            pooling always uses "gelu" for its MLP. Default is "silu".
        kernel_initializer : str, optional
            Initialization strategy for kernel weights of the equivariant modules, such as
            "he_normal". Default is "he_normal".
        dropout : int, float, or None, optional
            Dropout rate passed to both the equivariant modules and the attention pooling.
            Default is 0.05.
        spectral_normalization : bool, optional
            Whether to apply spectral normalization in the equivariant modules. Default is False.
        num_heads : int, optional
            Number of attention heads in the pooling layer. Default is 4.
        mlp_depth : int, optional
            Forwarded as ``mlp_depth`` to ``PoolingByMultiHeadAttention`` (depth of its internal
            MLP). Default is 2.
        mlp_width : int, optional
            Forwarded as ``mlp_width`` to ``PoolingByMultiHeadAttention`` (width of its internal
            MLP). Default is 128.
        num_seeds : int, optional
            Number of seed vectors used by the attention pooling. Default is 1.
        **kwargs
            Additional keyword arguments passed to the ``SummaryNetwork`` base class.

        Raises
        ------
        ValueError
            If ``mlp_widths_equivariant`` is empty.
        """

        super().__init__(**kwargs)

        # The pooling width is derived from the equivariant output width, so it must exist.
        if len(mlp_widths_equivariant) == 0:
            raise ValueError("mlp_widths_equivariant must contain at least one width.")

        # Stack of equivariant modules for a many-to-many learnable transformation
        self.equivariant_modules = [
            EquivariantLayer(
                mlp_widths_equivariant=mlp_widths_equivariant,
                mlp_widths_invariant_inner=mlp_widths_invariant_inner,
                mlp_widths_invariant_outer=mlp_widths_invariant_outer,
                activation=activation,
                kernel_initializer=kernel_initializer,
                spectral_normalization=spectral_normalization,
                dropout=dropout,
                pooling=inner_pooling,
            )
            for _ in range(depth)
        ]

        # Attention pooling: embedding and seed dims follow the last equivariant MLP width.
        embed_dim = mlp_widths_equivariant[-1]
        self.pooling_by_attention = PoolingByMultiHeadAttention(
            num_heads=num_heads,
            embed_dim=embed_dim,
            mlp_depth=mlp_depth,
            mlp_width=mlp_width,
            num_seeds=num_seeds,
            seed_dim=embed_dim,
            dropout=dropout,
            mlp_activation="gelu",
            layer_norm=True,
        )

        # Output linear layer to project set representation down to "summary_dim" learned summary statistics
        self.output_projector = keras.layers.Dense(summary_dim, activation="linear", name="output_projector")

        self.summary_dim = summary_dim

    def call(self, x: Tensor, training: bool = False, **kwargs) -> Tensor:
        """
        Performs the forward pass: equivariant transformations, attention pooling, linear projection.

        The input is passed through each equivariant module in order. The transformed set is then
        aggregated by the multi-head attention pooling layer, and the result is projected to
        ``summary_dim`` features by a linear layer.

        Parameters
        ----------
        x : Tensor
            Input tensor representing a batch of sets. Presumably shaped
            (batch_size, set_size, num_features); confirm against ``EquivariantLayer``.
        training : bool, optional
            Whether the model is in training mode, affecting layers like dropout. Default is False.
        **kwargs
            Additional keyword arguments forwarded only to the attention pooling layer
            (not to the equivariant modules).

        Returns
        -------
        output : Tensor
            Learned summary of the input set, with a last axis of size ``summary_dim``.
        """

        for em in self.equivariant_modules:
            x = em(x, training=training)

        x = self.pooling_by_attention(x, training=training, **kwargs)

        return self.output_projector(x)

