import torch
from omegaconf import DictConfig
from pado.core import PadoModule

# Public API of this module: only the loss base class is exported.
__all__ = ["PadoLoss"]


class PadoLoss(PadoModule):
    """
    Base class of losses.

    Stores a validated ``reduction`` mode on the instance for subclasses to
    use. Subclasses must implement ``forward`` and, if they should be
    constructible from configuration, ``from_config``.
    """

    # Accepted (lower-case) reduction modes. ``__init__`` also accepts
    # ``None`` and normalizes it to ``"none"``.
    _VALID_REDUCTIONS = ("none", "mean", "sum", "batch_mean")

    def __init__(self, *, reduction: "str | None" = "mean"):
        """
        Args:
            reduction: One of ``"none"``, ``"mean"``, ``"sum"`` or
                ``"batch_mean"`` (case-insensitive), or ``None`` as an alias
                for ``"none"``. The value is only validated and stored here;
                applying it is up to the subclass's ``forward``.

        Raises:
            ValueError: If ``reduction`` is not one of the accepted modes.
        """
        super().__init__()

        # Handle ``None`` before calling ``.lower()``, which only exists on
        # strings; otherwise the ``None`` alias raises AttributeError.
        if reduction is None:
            reduction = "none"
        reduction = reduction.lower()
        if reduction not in self._VALID_REDUCTIONS:
            raise ValueError(f"Invalid loss reduction {reduction}.")

        self.reduction = reduction

    def __call__(self, *args, **kwargs) -> torch.Tensor:
        """Compute the loss by calling ``forward`` directly with all arguments."""
        # NOTE(review): if PadoModule derives from torch.nn.Module, overriding
        # __call__ like this skips nn.Module.__call__ (and any registered
        # forward hooks) — confirm this is intended.
        return self.forward(*args, **kwargs)

    def forward(self, *args, **kwargs) -> torch.Tensor:
        """Compute the loss. Must be implemented by subclasses."""
        raise NotImplementedError

    @classmethod
    def from_config(cls, cfg: DictConfig):
        """Build a loss instance from ``cfg``. Must be implemented by subclasses."""
        raise NotImplementedError
