import torch
import torch.nn as nn
import torch.nn.functional as F
class PairwiseBCELoss(nn.Module):
    """Pairwise logistic ranking loss, aka RankNet loss.

    For each pair (item 1, item 2), the difference ``pred1 - pred2`` is
    treated as the logit of P(item 1 ranks above item 2) and trained with
    binary cross-entropy against a target derived from the ground-truth
    scores:

    * ``1.0`` when ``score1 > score2``,
    * ``0.0`` when ``score1 < score2``,
    * ``0.5`` when ``score1 == score2`` (no preferred order; this matches
      RankNet's target ``(1 + S_ij) / 2``).

    Args:
        reduction: Forwarded to ``F.binary_cross_entropy_with_logits``:
            ``"mean"`` (default), ``"sum"`` or ``"none"``.
    """

    def __init__(self, reduction: str = "mean"):
        super().__init__()
        self.reduction = reduction

    def forward(self, pred1, pred2, score1, score2):
        """Compute the pairwise BCE loss.

        Args:
            pred1: Model logits for the first item of each pair.
            pred2: Model logits for the second item of each pair.
            score1: Ground-truth scores for the first item of each pair.
            score2: Ground-truth scores for the second item of each pair.
                ``pred1 - pred2`` must have the same shape as the broadcast
                of ``score1`` and ``score2``.

        Returns:
            The loss reduced according to ``self.reduction``. With
            ``"none"``, this is a per-pair loss tensor.
        """
        pred_diff = pred1 - pred2
        # Build the target with comparisons rather than sign(score1 - score2)
        # so that integer/unsigned scores cannot overflow or wrap. Match the
        # target dtype to the logits so precision is not silently reduced.
        greater = (score1 > score2).to(pred_diff.dtype)
        tied = (score1 == score2).to(pred_diff.dtype)
        # 1 if item 1 should rank higher, 0 if lower, 0.5 on a tie.
        label = greater + 0.5 * tied
        return F.binary_cross_entropy_with_logits(
            pred_diff, label, reduction=self.reduction
        )