from enum import Enum
from typing import Protocol

import torch
from torch.nn.functional import binary_cross_entropy_with_logits


class LossFunctionType(Protocol):
    """
    Protocol for loss functions.

    The __call__ method should take two arguments: pred and target. The pred is the model's output, and the target is
    the ground truth. The return value should be a scalar tensor representing the loss. If the target is a scalar, it
    should be broadcasted to the same shape as the pred. The pred should take a value in [0, 1].
    """

    def __call__(self, pred: torch.Tensor, target: torch.Tensor | int) -> torch.Tensor: ...


def cross_entropy_loss(pred, target):
    if type(target) is int:
        target = torch.tensor([target] * len(pred), dtype=torch.float32, device=pred.device)
    return binary_cross_entropy_with_logits(pred, target)


def sigmoid_loss(pred, target):
    return torch.sigmoid(-pred * (2 * target - 1)).mean()


def zero_one_loss(pred: torch.Tensor, target: torch.Tensor | int, aggregate: bool = True) -> torch.Tensor:
    if type(target) is int:
        target = torch.tensor([target] * len(pred), dtype=torch.float32, device=pred.device)
    """Returns zero-one loss."""
    if aggregate:
        return 1 - torch.mean(((pred * (2 * target - 1)) > 0).float())
    else:
        return 1 - (pred * (2 * target - 1) > 0).float()


class LossFunction(Enum):
    """Named, callable selector over this module's loss functions.

    Calling a member, e.g. ``LossFunction.SIGMOID(pred, target)``, forwards to the
    corresponding module-level function. ``ZERO_ONE`` always uses that function's
    default aggregation.
    """

    CROSS_ENTROPY = "cross_entropy"
    SIGMOID = "sigmoid"
    ZERO_ONE = "zero_one"

    def __call__(self, pred, target):
        if self is LossFunction.CROSS_ENTROPY:
            return cross_entropy_loss(pred, target)
        if self is LossFunction.SIGMOID:
            return sigmoid_loss(pred, target)
        if self is LossFunction.ZERO_ONE:
            return zero_one_loss(pred, target)
        # Unreachable for the members above; kept as a guard for future members.
        raise ValueError("Invalid loss function")
