from lore.f.base import f
import torch
from math import comb


class LogisticTripletLoss(f):
    """Triplet loss with a smooth logistic (softplus) penalty.

    For each triplet ``(anchor, positive, negative)`` of embedding rows the
    per-triplet loss is ``log(1 + exp(d(a, p) - d(a, n) + margin))``, where
    ``d`` is the Euclidean distance; ``forward`` returns the mean over all
    triplets. Unlike ``torch.nn.functional.triplet_margin_loss``, which
    applies a hinge ``max(0, .)``, this penalty is smooth and never exactly
    zero.

    Args:
        triplets: Integer indices into the embedding's first axis. Columns
            0, 1 and 2 are used as anchor, positive and negative
            respectively; presumably shaped ``(n_triplets, 3)``. Anything
            accepted by ``torch.as_tensor`` works; it is stored as int64.
        margin: Offset added to ``d(a, p) - d(a, n)`` before the penalty.
    """

    def __init__(self, triplets, margin):
        super().__init__()
        # as_tensor lets callers pass lists/arrays as well as tensors; an
        # existing tensor is reused as-is (same device).
        self.triplets = torch.as_tensor(triplets).long()
        self.margin = margin

    def forward(self, embedding, eps=1e-3):
        """Return the mean logistic triplet loss for ``embedding``.

        Args:
            embedding: Tensor indexed by ``self.triplets`` along its first
                axis; assumed ``(n_points, dim)`` -- the slicing below
                requires ``embedding[self.triplets]`` to be 3-D.
            eps: Added to both distances. NOTE(review): since it is added to
                both terms it cancels in their difference and does not
                change the result (up to rounding); kept for interface
                compatibility.

        Returns:
            Scalar tensor: the mean per-triplet loss.
        """
        points = embedding[self.triplets]
        anchor = points[:, 0, :]
        positive = points[:, 1, :]
        negative = points[:, 2, :]
        dist_ap = torch.norm(anchor - positive, p=2, dim=1) + eps
        dist_an = torch.norm(anchor - negative, p=2, dim=1) + eps
        triplet_loss = dist_ap - dist_an + self.margin
        # softplus(x) == log(1 + exp(x)) without overflow: the naive form
        # returns inf (and NaN gradients) for large x, e.g. x > ~88 in float32.
        batch_loss = torch.nn.functional.softplus(triplet_loss)
        return batch_loss.mean()