import torch
import torch.nn.functional as F


def clip_bce(output_dict, target_dict):
    """Compute binary cross-entropy between clip-level predictions and labels.

    Args:
        output_dict: Mapping with the model predictions under the key
            ``"clipwise_output"``. They are passed directly to
            ``F.binary_cross_entropy``, which expects probabilities (values
            in [0, 1]), not logits.
        target_dict: Mapping with the ground-truth labels under the key
            ``"target"``.

    Returns:
        The loss tensor produced by ``F.binary_cross_entropy`` with its
        default reduction (``"mean"``).
    """
    # Look up predictions before targets, matching the original access order.
    predictions = output_dict["clipwise_output"]
    labels = target_dict["target"]
    return F.binary_cross_entropy(predictions, labels)


def get_loss_func(loss_type):
    """Return the loss function registered under ``loss_type``.

    Args:
        loss_type: Name of the loss. Currently only ``"clip_bce"`` is
            supported.

    Returns:
        The loss callable, taking ``(output_dict, target_dict)``.

    Raises:
        ValueError: If ``loss_type`` is not a recognised loss name. Previously
            this silently returned ``None``, which only failed later when the
            caller tried to invoke the loss.
    """
    if loss_type == "clip_bce":
        return clip_bce
    raise ValueError(
        f"Unknown loss_type {loss_type!r}; expected one of: 'clip_bce'"
    )
