from typing import Protocol

import torch
from torch.nn import CrossEntropyLoss
from typing_extensions import runtime_checkable


@runtime_checkable
class CELossProtocol(Protocol):
    """Protocol for cross entropy loss callables.

    Any callable that accepts ``(logits, targets)`` positionally and returns
    a tensor satisfies this protocol, e.g. a ``torch.nn.CrossEntropyLoss()``
    instance or ``torch.nn.functional.cross_entropy``.

    Note:
        Because the class is ``@runtime_checkable``, ``isinstance(obj,
        CELossProtocol)`` works, but it only checks that ``obj`` has a
        ``__call__`` attribute. Argument and return types are enforced by
        static type checkers only.

    """

    def __call__(
        self,
        logits: torch.Tensor,
        targets: torch.Tensor,
        /,
    ) -> torch.Tensor:
        """Calculate cross entropy loss between logits and target indices.

        Both arguments are positional-only. PyTorch's implementations name
        them ``input`` and ``target``, so allowing keyword calls such as
        ``loss_fn(logits=..., targets=...)`` would pass type checking against
        this protocol but raise ``TypeError`` at runtime with those
        implementations.

        Args:
            logits: Raw prediction scores (before softmax).
            targets: Target class indices.

        Returns:
            Loss value.

        """
        ...


# Static conformance check: type checkers (e.g. mypy, pyright) verify that a
# CrossEntropyLoss instance is assignable to CELossProtocol. The annotation
# performs no check at runtime; executing this line only constructs a
# CrossEntropyLoss and binds it to the throwaway name ``_``.
_: CELossProtocol = CrossEntropyLoss()
