from abc import ABC, abstractmethod
from typing import Any

import torch
from torch import nn


class FusingFunction(ABC):
    """Abstract interface for callables that fuse two tensors into one.

    Concrete subclasses must override `__call__`. Because that method is
    declared abstract, this base class cannot be instantiated directly.
    """

    @abstractmethod
    def __call__(self, s: torch.Tensor, r: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """Combine `s` and `r` into a single fused representation.

        Args:
            s (torch.Tensor): First input, shaped `(B, f)` where `B` is the
                batch size and `f` is the feature dimension.
            r (torch.Tensor): Second input. Typically also `(B, f)` like `s`,
                but some fusing functions (e.g. RESCAL) expect an `r` whose
                shape differs from that of `s`.
            **kwargs: Additional keyword arguments; their meaning is defined
                by each concrete implementation.

        Returns:
            torch.Tensor: The fused representation, shaped `(B, f)`.

        """


class ParametricFusingFunction(nn.Module, FusingFunction):
    """Fusing function implemented as a `torch.nn.Module`.

    Being a module lets subclasses register parameters and submodules.
    Subclasses implement `forward`. Calling an instance goes through the
    module call machinery, so module hooks apply around `forward`.
    """

    def __init__(self) -> None:
        super().__init__()

    @abstractmethod
    def forward(self, s: torch.Tensor, r: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """Compute the fused representation of `s` and `r`.

        Args:
            s (torch.Tensor): First input, shaped `(B, f)` where `B` is the
                batch size and `f` is the feature dimension.
            r (torch.Tensor): Second input. Typically also `(B, f)` like `s`,
                but some fusing functions (e.g. RESCAL) expect an `r` whose
                shape differs from that of `s`.
            **kwargs: Additional keyword arguments; their meaning is defined
                by each concrete implementation.

        Returns:
            torch.Tensor: The fused representation, shaped `(B, f)`.

        """

    def __call__(self, s: torch.Tensor, r: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """Fuse `s` and `r` by delegating to the next `__call__` in the MRO.

        This override exists to expose the fusing signature instead of
        `nn.Module`'s generic one. For a direct subclass, the delegate is
        `nn.Module.__call__`, which dispatches to `forward`.
        """
        fused = super().__call__(s, r, **kwargs)
        return fused
