import torch
from einops import rearrange, reduce
from torch import nn

from .base import ParametricFusingFunction


class TransformerEncoder(ParametricFusingFunction):
    """Fuse two representations with a Transformer encoder.

    The two inputs are treated as a short token sequence, passed through a
    stack of Transformer encoder layers, mean-pooled over the token axis, and
    finally projected by a linear layer.
    """

    def __init__(
        self,
        dimension: int,
        num_layers: int = 2,
        num_heads: int = 8,
        dropout: float = 0.1,
        dim_feedforward: int = 2048,
    ):
        """Initialize the module.

        Args:
            dimension (int): The input dimension. It is used as the model dimension of the
                transformer, so it must be divisible by ``num_heads``.
            num_layers (int, optional): The number of Transformer layers.
            num_heads (int, optional): The number of self-attention heads inside each transformer
                encoder layer.
            dropout (float, optional): The dropout rate on each transformer encoder layer.
            dim_feedforward (int, optional): The hidden dimension of the feed-forward layers of the
                transformer encoder layer.

        """
        super().__init__()
        # The encoder layer is created with the default ``batch_first=False``,
        # so ``forward`` must feed it sequence-first tensors.
        self.transformer = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(
                d_model=dimension,
                nhead=num_heads,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
            ),
            num_layers=num_layers,
        )
        self.final = nn.Linear(dimension, dimension, bias=True)

    def forward(self, s: torch.Tensor, r: torch.Tensor, **kwargs) -> torch.Tensor:
        """Evaluate the interaction function.

        Args:
            s, r (torch.Tensor): The input tensors with shape `(B, f)`, where
                `B` is the batch size,
                `f` is the feature dimension.
                Tensors that already carry a token axis, i.e. shape `(B, n, f)`,
                are also accepted and are concatenated along that axis.
            **kwargs: Accepted for interface compatibility; ignored here.

        Returns:
            torch.Tensor: The fused representation with shape `(B, f)`.

        """
        # Give 2-D inputs an explicit token axis. Concatenating `(B, f)`
        # tensors directly along dim=1 would yield `(B, 2f)`, which is neither
        # 3-D nor compatible with the transformer's model dimension.
        if s.dim() == 2:
            s = s.unsqueeze(1)
        if r.dim() == 2:
            r = r.unsqueeze(1)

        xs = torch.cat([s, r], dim=1)  # (B, n, f); n == 2 for 2-D inputs
        # The transformer expects sequence-first input (batch_first=False).
        xs = rearrange(xs, "b n f -> n b f")
        xs = self.transformer(src=xs)
        # Mean-pool over the token axis to get one vector per batch element.
        x = reduce(xs, "n b f -> b f", "mean")
        return self.final(x)
