# forward_forward/losses.py

import torch
import torch.nn.functional as F

def ff_margin_loss(goodness_pos, goodness_neg, tau=2.0):
    """Margin-based separation loss for goodness scores.

    Penalizes positive samples whose goodness falls below ``tau`` and
    negative samples whose goodness rises above it::

        loss = mean(softplus(tau - goodness_pos))
             + mean(softplus(goodness_neg - tau))

    Args:
        goodness_pos: Tensor of goodness scores for positive samples.
        goodness_neg: Tensor of goodness scores for negative samples. It may
            have a different number of elements than ``goodness_pos``.
        tau: Goodness threshold that separates positives from negatives.

    Returns:
        Scalar tensor holding the loss.
    """
    # F.softplus is the numerically stable form of log1p(exp(x)); the naive
    # expression overflows to inf (and yields NaN gradients) for large x.
    pos_term = F.softplus(tau - goodness_pos)
    neg_term = F.softplus(goodness_neg - tau)
    # Averaging each term separately equals mean(pos_term + neg_term) whenever
    # the two shapes broadcast, and also supports unequal batch sizes.
    return pos_term.mean() + neg_term.mean()

def ff_symba_loss(goodness_pos, goodness_neg, alpha=1.0):
    """SymBa (symmetric bipolar) loss.

    Computes ``mean(softplus(-alpha * (goodness_pos - goodness_neg)))``.
    The loss shrinks toward 0 as each positive sample's goodness exceeds its
    paired negative's, and grows linearly when the order is reversed.

    Args:
        goodness_pos: Tensor of goodness scores for positive samples.
        goodness_neg: Tensor of goodness scores for negative samples, paired
            element-wise with (broadcastable against) ``goodness_pos``.
        alpha: Scale applied to the goodness gap; the default of 1.0 keeps
            the gap unscaled.

    Returns:
        Scalar tensor holding the loss.
    """
    gap = goodness_pos - goodness_neg
    # The minus sign makes the loss decrease as positives out-score negatives,
    # consistent with ff_margin_loss, which rewards high positive goodness.
    # F.softplus avoids the overflow of log1p(exp(x)) for large |gap|.
    return F.softplus(-alpha * gap).mean()

# BCE with label smoothing: `margin` pulls the hard 0/1 targets toward 0.5.
def ff_bce_loss(logits, target, margin=0):
    """Binary cross-entropy on logits with label-smoothed targets.

    Each target ``t`` is mapped to ``t * (1 - margin) + margin / 2``, so a
    hard 0 becomes ``margin / 2`` and a hard 1 becomes ``1 - margin / 2``.
    Larger margins keep the logits from being pushed toward +/- infinity.

    Args:
        logits: Raw (pre-sigmoid) scores.
        target: Binary targets, presumably 0/1 labels with the same shape
            as ``logits``.
        margin: Smoothing amount in ``[0, 1]``; 0 leaves targets unchanged.

    Returns:
        Scalar tensor with the mean BCE loss.

    Raises:
        ValueError: If ``margin`` is negative, greater than 1, or NaN.
    """
    if margin < 0:
        raise ValueError("Margin must be non-negative.")
    # A margin above 1 pushes the smoothed targets across 0.5 (flipping the
    # labels) or outside [0, 1]; the negated comparison also rejects NaN.
    if not margin <= 1:
        raise ValueError(f"Margin must be at most 1, got {margin}.")
    smoothed = target * (1 - margin) + margin / 2
    return F.binary_cross_entropy_with_logits(logits, smoothed)
