import torch
import torch.nn.functional as F

def polar_loss(mask, edge_gt, pos_weight=10.0):
    """Return the positive-class-weighted binary cross-entropy of ``mask``.

    Both inputs are flattened, so any shapes with the same number of
    elements are accepted. The loss is averaged over every element.

    Args:
        mask: Raw, pre-sigmoid logits (a floating-point tensor). The sigmoid
            is applied inside ``F.binary_cross_entropy_with_logits``.
        edge_gt: Binary targets with the same number of elements as
            ``mask``. They may be bool, integer or floating dtype and are
            cast to ``mask``'s dtype. Presumably an edge ground-truth map —
            TODO confirm against callers.
        pos_weight: Weight on the positive-target term of the loss, used to
            up-weight the positive class. May be a Python number or a tensor.

    Returns:
        A 0-dim tensor holding the mean loss, in ``mask``'s dtype. An empty
        input yields NaN (mean over zero elements).
    """
    # reshape (unlike view) also accepts non-contiguous tensors, e.g. the
    # result of a transpose or permute.
    logits = mask.reshape(-1)
    # Follow the logits' dtype instead of hard-coding float32, so double or
    # half inputs are not mixed with float32 targets inside the loss.
    targets = edge_gt.reshape(-1).to(logits.dtype)
    # as_tensor accepts both numbers and tensors and avoids a needless copy.
    weight = torch.as_tensor(pos_weight, dtype=logits.dtype, device=logits.device)

    # reduction='mean' is exactly reduction='none' followed by .mean().
    return F.binary_cross_entropy_with_logits(
        logits,
        targets,
        pos_weight=weight,
        reduction='mean',
    )