
# Factual loss evaluation
def factual_loss(y_true, t_true, y0_pred, y1_pred):
    """
    Compute the factual loss: summed squared error on the observed outcomes.

    Each sample's observed outcome ``y_true`` is compared with ``y0_pred``
    weighted by ``1 - t_true`` and with ``y1_pred`` weighted by ``t_true``.
    With a 0/1 indicator, only the prediction for the arm the sample
    actually received contributes to the loss.

    Args:
        y_true: Observed outcomes tensor.
        t_true: Treatment indicator tensor (float, int or bool). Presumably
            0/1; non-binary values act as soft weights between the two terms.
        y0_pred: Predictions scored against samples with ``t_true == 0``.
        y1_pred: Predictions scored against samples with ``t_true == 1``.

    All four inputs have their singleton dimensions squeezed. Because of
    this, ``(N,)`` and ``(N, 1)`` shapes can be mixed freely.

    Returns:
        A 0-d tensor holding the summed (not averaged) squared error.
    """
    # Squeeze every input, not just the predictions. Otherwise a (N, 1) target
    # minus a (N,) prediction would silently broadcast to (N, N).
    y_true = y_true.squeeze()
    t_true = t_true.squeeze()
    y0_pred = y0_pred.squeeze()
    y1_pred = y1_pred.squeeze()

    # torch does not support `1. - t` on bool tensors, so turn a boolean
    # treatment mask into numeric weights first.
    if t_true.dtype == torch.bool:
        t_true = t_true.to(y0_pred.dtype)

    loss0 = torch.sum((1. - t_true) * torch.square(y_true - y0_pred))
    loss1 = torch.sum(t_true * torch.square(y_true - y1_pred))

    return loss0 + loss1

