import torch
import torch.nn as nn

def all_pairs_venn_loss(z, y):
    """Unweighted binary cross-entropy over all pairs.

    z: predicted probabilities in [0, 1], shape (B, B).
    y: targets with the same shape as ``z``; cast to float before use.

    Returns the mean BCE (PyTorch's default reduction) as a scalar tensor.
    """
    target = y.float()
    return nn.functional.binary_cross_entropy(z, target)

def all_pairs_weighted_venn_loss(z, y):
    """Class-balanced binary cross-entropy over all pairs.

    z: predicted probabilities in [0, 1], typically of shape (B, B); any
       shape matching ``y`` is accepted.
    y: targets with the same shape as ``z``. Entries > 0 receive the
       positive weight; all other entries receive weight 1.

    Positive entries are weighted by ``y.numel() / y.sum() - 1`` (for binary
    targets this is the negative-to-positive ratio), so rare positives are
    not swamped by the many negatives.

    Returns the mean weighted BCE as a scalar tensor.
    """
    # numel() equals B * B for a (B, B) input but also works for any rank.
    total = y.numel()
    # If y has no positives this evaluates to inf, but torch.where below
    # never selects it because no entry satisfies y > 0.
    posw = (total / y.sum()) - 1.0
    # Put the weight on z's device/dtype, since BCE multiplies it into a
    # loss tensor that has z's dtype.
    weight = torch.where(y > 0.0, posw, 1.0).to(device=z.device, dtype=z.dtype)
    return nn.functional.binary_cross_entropy(z, y.float(), weight=weight)

def all_pairs_venn_loss_3c(z, y):
    """Three-way cross-entropy over all pairs.

    z: unnormalized logits of shape (B, B, 3), one score per relation class:
        0: no intersection (test_neg)
        1: A <= B (test_pos1)
        2: B <= A (test_pos2)
    y: class labels in {0, 1, 2} with B * B elements (typically shape (B, B)).
       It is cast to int64, because ``cross_entropy`` requires Long
       class-index targets. Any integer dtype, or a float tensor holding
       whole-number labels, is therefore accepted.

    Returns the mean cross-entropy over all B * B pairs as a scalar tensor.
    """
    logits = z.reshape(-1, 3)        # (B*B, 3)
    labels = y.reshape(-1).long()    # (B*B,) class indices
    return nn.functional.cross_entropy(logits, labels)
