from typing import Any

import torch
from einops import rearrange

from .base import FusingFunction


class RESCAL(FusingFunction):
    """RESCAL fusing function: multiplies the subject vector by a relation matrix.

    Each relation is stored as a flattened ``f x f`` matrix ``R``. Fusing a
    subject embedding ``s`` with it yields the row-vector product ``s @ R``.
    (Presumably the caller then scores this against an object embedding ``o``
    with a dot product, giving RESCAL's bilinear score ``s^T R o`` — confirm
    against the caller.)
    """

    def __call__(self, s: torch.Tensor, r: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """Merge subject and relation embeddings using RESCAL's matrix multiplication approach.

        Leading dimensions are treated as batch dimensions and broadcast
        against each other. The original ``(B, f)`` / ``(B, f*f)`` case works
        exactly as before. So do:

        - unbatched ``(f,)`` / ``(f*f,)`` inputs;
        - ``(B, N, f)`` / ``(B, N, f*f)`` inputs;
        - a single shared relation ``(f*f,)`` applied to a batch ``(B, f)``.

        Args:
            s (torch.Tensor): Subject entity embeddings with shape (..., f).
            r (torch.Tensor): Relation embeddings with shape (..., f*f). Each is
                a row-major flattened ``f x f`` matrix, i.e.
                ``r[..., i * f + j]`` is entry ``R[i, j]``.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            torch.Tensor: Fused representation ``s @ R`` with shape (..., f).
            Here ``...`` is the broadcast of the leading dimensions of ``s``
            and ``r``.

        Raises:
            einops.EinopsError: If the last dimension of ``r`` is not ``f * f``.
            RuntimeError: If the leading dimensions of ``s`` and ``r`` cannot
                be broadcast together.

        """
        dim = s.size(-1)
        # Unflatten the relation vectors into matrices. "(d1 d2)" makes d1 the
        # outer (row) index, so the flattened layout is row-major.
        r_matrices = rearrange(r, "... (d1 d2) -> ... d1 d2", d1=dim, d2=dim)
        # Treat s as a row vector: (..., 1, f) @ (..., f, f) -> (..., 1, f).
        # torch.matmul broadcasts leading dims. torch.bmm, by contrast,
        # requires exactly 3-D inputs with equal batch sizes.
        fused = torch.matmul(s.unsqueeze(-2), r_matrices)
        # (..., 1, f) -> (..., f)
        return fused.squeeze(-2)
