from typing import Protocol

import torch
import torch.nn.functional as F
from torch import nn


class ConstraintLayer(Protocol):
    """Structural interface for anything that constrains an embedding tensor.

    Any callable object exposing a compatible ``__call__`` (tensor in,
    tensor out) satisfies this protocol; no explicit subclassing is needed.
    """

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with the constraint enforced.

        Args:
            x: Tensor shaped (batch_size, embedding_dim) or
               (num_entities/relations, embedding_dim).

        Returns:
            A tensor with the same shape as ``x``.

        """


class L2Normalization(nn.Module):
    """Project each embedding vector onto the unit L2 sphere.

    Normalization runs over the last dimension only, so any leading
    (batch / entity) dimensions are left untouched.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Rescale every last-axis vector of ``x`` to unit Euclidean length.

        Args:
            x: Tensor of shape (*, embedding_dim).

        Returns:
            Tensor with the same shape as ``x``. ``F.normalize`` clamps the
            denominator with a small epsilon, so all-zero vectors stay zero
            instead of producing NaNs.

        """
        embedding_axis = -1
        return F.normalize(x, dim=embedding_axis, p=2)


# Dropout constraint: torch.nn.Dropout is re-exported as-is, because an
# instance maps a tensor to a tensor of the same shape and therefore already
# satisfies the ConstraintLayer protocol.
Dropout = torch.nn.Dropout


class Identity(nn.Module):
    """Identity constraint that returns its input unchanged.

    Subclasses ``nn.Module``, like the other constraint layers in this module
    (``L2Normalization`` and the ``nn.Dropout`` alias). It can therefore be
    swapped in for them anywhere a module is expected, e.g. inside
    ``nn.Sequential`` / ``nn.ModuleList`` or as a registered submodule.
    It holds no parameters or buffers, so it adds nothing to a state_dict.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the input tensor unchanged.

        Args:
            x: Input tensor of any shape.

        Returns:
            ``x`` itself (the same object; no copy is made).

        """
        return x
