from typing import Protocol

import torch
from torch.nn import BCELoss, BCEWithLogitsLoss
from typing_extensions import runtime_checkable


@runtime_checkable
class BCELossProtocol(Protocol):
    """Structural type for binary cross entropy losses over probabilities.

    Any callable that accepts ``(predictions, targets)`` positionally and
    returns a tensor satisfies this protocol, e.g. ``torch.nn.BCELoss()``.
    """

    def __call__(
        self,
        predictions: torch.Tensor,
        targets: torch.Tensor,
        /,
    ) -> torch.Tensor:
        """Calculate binary cross entropy loss between predictions and targets.

        Both arguments are positional-only. Concrete torch losses name their
        parameters ``input``/``target``, so a keyword call using the names
        below would fail at runtime. Keeping them positional-only stops a
        type checker from accepting such calls.

        Args:
            predictions: Predicted probabilities, expected to lie in [0, 1].
            targets: Target values in [0, 1] (commonly hard 0/1 labels).

        Returns:
            Loss value.

        """
        ...


@runtime_checkable
class BCEWithLogitsLossProtocol(Protocol):
    """Structural type for binary cross entropy losses that accept logits.

    Any callable that accepts ``(predictions, targets)`` positionally and
    returns a tensor satisfies this protocol, e.g.
    ``torch.nn.BCEWithLogitsLoss()``.
    """

    def __call__(
        self,
        predictions: torch.Tensor,
        targets: torch.Tensor,
        /,
    ) -> torch.Tensor:
        """Calculate binary cross entropy loss between logits and targets.

        Both arguments are positional-only. Concrete torch losses name their
        parameters ``input``/``target``, so a keyword call using the names
        below would fail at runtime. Keeping them positional-only stops a
        type checker from accepting such calls.

        Args:
            predictions: Predicted logits (unbounded real values).
            targets: Target values in [0, 1] (commonly hard 0/1 labels).

        Returns:
            Loss value.

        """
        ...


# Protocol conformance checks for the stock torch BCE losses.
#
# The annotated assignments are checked only by static type checkers; Python
# does not enforce annotations at runtime. The explicit isinstance checks
# below do the runtime verification (the protocols are runtime_checkable).
# A runtime_checkable isinstance check only confirms that a ``__call__``
# member exists. It does not validate the call signature.
#
# Separate names are used because annotating one name with two different
# types can be reported as a redefinition by type checkers.
_bce_loss_check: BCELossProtocol = BCELoss()
_bce_with_logits_loss_check: BCEWithLogitsLossProtocol = BCEWithLogitsLoss()

if not isinstance(_bce_loss_check, BCELossProtocol):
    raise TypeError("BCELoss does not satisfy BCELossProtocol")
if not isinstance(_bce_with_logits_loss_check, BCEWithLogitsLossProtocol):
    raise TypeError(
        "BCEWithLogitsLoss does not satisfy BCEWithLogitsLossProtocol"
    )


class Model:
    """Score a tensor with a BCE loss against an all-zero target."""

    def __init__(self) -> None:
        # Typed as the protocol rather than the concrete BCELoss class.
        self.loss_fn: BCELossProtocol = BCELoss()

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``self.loss_fn(x, zeros)``, where ``zeros`` is shaped like ``x``.

        Args:
            x: Predictions passed straight to the loss. With the default
                ``BCELoss`` these are presumably probabilities in [0, 1]
                (TODO confirm against callers).

        Returns:
            The tensor produced by ``self.loss_fn``.
        """
        zero_targets = torch.zeros_like(x)
        loss_fn = self.loss_fn
        return loss_fn(x, zero_targets)
