import torch.nn.functional as F
import torch


def binary_cross_entropy(input, meta):
    """Binary cross-entropy computed directly on raw logits.

    Args:
        input: Logit tensor.
        meta: Mapping whose ``'target'`` entry holds the labels. The labels are
            reshaped to ``input.shape``, so a flat label vector can be paired
            with e.g. ``(N, 1)`` logits.

    Returns:
        The loss from ``F.binary_cross_entropy_with_logits`` with its default
        (mean) reduction.
    """
    target = meta['target'].reshape(input.shape)
    # binary_cross_entropy_with_logits does not accept integer/bool targets,
    # which is a common way to store 0/1 labels; promote them to the logits'
    # dtype. Floating-point targets are passed through unchanged.
    if not target.is_floating_point():
        target = target.to(input.dtype)
    return F.binary_cross_entropy_with_logits(input, target)


def cross_entropy(input, meta):
    """Multi-class cross-entropy between logits and the labels in ``meta``.

    ``meta['target']`` is forwarded as-is to ``F.cross_entropy``, called with
    its default arguments.
    """
    labels = meta['target']
    loss = F.cross_entropy(input, labels)
    return loss


def mse_loss(input, meta, reduction='mean'):
    """Squared-error loss between ``input`` and ``meta['target']``.

    Args:
        input: Prediction tensor.
        meta: Mapping providing the reference values under ``'target'``.
        reduction: Forwarded to ``F.mse_loss`` (``'mean'`` by default).
    """
    reference = meta['target']
    return F.mse_loss(input, reference, reduction=reduction)


def clip_loss(input, meta):
    """Symmetric contrastive loss over a square similarity/logit matrix.

    ``input[2]`` holds the matrix; entry ``(i, i)`` is treated as the correct
    match for both row ``i`` and column ``i``. The result is the average of the
    row-wise and column-wise cross-entropies. ``meta`` is not used.
    """
    logits = input[2]
    n = logits.shape[0]
    # Correct "class" for row i (and column i) is index i: the diagonal.
    diagonal = torch.arange(n, device=logits.device)
    row_term = F.cross_entropy(logits, diagonal)
    col_term = F.cross_entropy(logits.t(), diagonal)
    return 0.5 * (row_term + col_term)


def triplet_loss_l2_with_swap(input, meta, margin=0.5):
    """Triplet margin loss on squared L2 distances, with the "swap" variant.

    Row ``i`` of ``input[0]`` (anchor) and row ``i`` of ``input[1]`` (positive)
    form a positive pair. Negatives are ``input[1]`` rotated by half the batch.
    The negative distance is the smaller of the anchor-negative and
    positive-negative distances. ``meta`` is not used.

    Returns:
        Mean over the batch of ``relu(d_pos - d_neg + margin)``.
    """
    anchors, positives = input[0], input[1]
    half = anchors.shape[0] // 2
    # Rotating left by `half` gives the same rows as
    # cat([positives[half:], positives[:half]]): sample i meets sample
    # (i + half) % batch_size as its negative.
    negatives = torch.roll(positives, shifts=-half, dims=0)

    def sq_l2(x, y):
        # Row-wise squared Euclidean distance.
        return (x - y).pow(2).sum(dim=1)

    d_pos = sq_l2(anchors, positives)
    d_neg = torch.minimum(sq_l2(anchors, negatives), sq_l2(positives, negatives))
    return F.relu(d_pos - d_neg + margin).mean()


def triplet_loss_l2(input, meta, margin=0.5):
    """Symmetric triplet margin loss on squared L2 distances.

    Rows ``i`` of ``input[0]`` and ``input[1]`` form a positive pair. The loss
    is computed twice, once with each input as the anchor side. In each
    direction, the negatives are the *other* input rotated by half the batch.
    The two mean hinge losses are averaged. ``meta`` is not used.
    """
    emb_a, emb_b = input[0], input[1]
    shift = -(emb_a.shape[0] // 2)
    # Positive distance is symmetric, so it is shared by both directions.
    d_pos = (emb_a - emb_b).pow(2).sum(dim=1)

    # Direction a->b: anchors from emb_a, negatives from rotated emb_b.
    d_neg_ab = (emb_a - torch.roll(emb_b, shifts=shift, dims=0)).pow(2).sum(dim=1)
    # Direction b->a: anchors from emb_b, negatives from rotated emb_a.
    d_neg_ba = (emb_b - torch.roll(emb_a, shifts=shift, dims=0)).pow(2).sum(dim=1)

    loss_ab = F.relu(d_pos - d_neg_ab + margin).mean()
    loss_ba = F.relu(d_pos - d_neg_ba + margin).mean()
    return 0.5 * (loss_ab + loss_ba)


def triplet_loss_cos(input, meta, margin=0.5):
    """Triplet margin loss on row-wise dot-product similarities.

    Rows ``i`` of ``input[0]`` (anchor) and ``input[1]`` (positive) form a
    positive pair. Negatives are ``input[1]`` rotated by half the batch. The
    scores are similarities (higher means closer), so the hinge is
    ``neg_sim - pos_sim + margin``. ``meta`` is not used.

    NOTE(review): the code takes raw inner products. These equal cosine
    similarities only if the inputs are L2-normalized; presumably the caller
    does this. Confirm.
    """
    anchors, positives = input[0], input[1]
    half = anchors.shape[0] // 2
    # Same rows as cat([positives[half:], positives[:half]]).
    negatives = torch.roll(positives, shifts=-half, dims=0)
    pos_sim = torch.einsum("bd, bd -> b", anchors, positives)
    neg_sim = torch.einsum("bd, bd -> b", anchors, negatives)
    return F.relu(neg_sim - pos_sim + margin).mean()
