import torch
from torch.nn import BCELoss


class TripletLoss:
    def __init__(self, margin):# margin  用于控制正样本和负样本之间的距离差异
        if not isinstance(margin, float):
            raise TypeError("margin should be a float")
        self._margin = margin

    def __call__(self, y_pred, y_true):#
        z0 = torch.zeros_like(y_pred)
        pos_avg = torch.sum(torch.where(y_true == 1, y_pred, z0), dim=-1)#正样本和
        # avoid nan
        denom1 = torch.sum(y_true, dim=-1)#正样本数量
        pos_avg /= torch.max(torch.ones_like(denom1), denom1)#归一化正样本平均值 避免分母为0
        neg_avg = torch.sum(torch.where(y_true == 0, y_pred, z0), dim=-1)#负样本和
        # avoid nan
        denom2 = torch.sum(1 - y_true, dim=-1)#负样本数
        neg_avg /= torch.max(torch.ones_like(denom2), denom2)#归一化负样本平均值
        loss = torch.max(torch.zeros_like(pos_avg),
                         #计算三元组损失，它是正样本相似性得分和负样本相似性得分之差再加上间隔 self._margin。如果这个差值小于零，
                         # 将损失设为零；
                         # 否则，保留这个差值作为损失。这个操作确保了正样本的相似性得分高于负样本得分至少 self._margin
                         neg_avg - pos_avg + self._margin)
        # loss = max(0, pos_dist - neg_dist + margin)
        #      = max(0, neg_sims - pos_sims + margin
        return torch.mean(loss)#所有样本的损失的平均值


class CrossEntropyLoss:#二进制交叉熵损失
    def __init__(self):
        self._func = BCELoss()

    def __call__(self, y_pred, y_true):
        return self._func(y_pred, y_true.float())
