import torch
from .losses import Loss


class TripletMarginSP(Loss):
    """Triplet margin loss computed on per-anchor mean distances.

    For every anchor ``a`` the loss is::

        relu(mean_{i in pos(a)} d(a, s_i) - mean_{j in neg(a)} d(a, s_j) + margin)

    where ``d`` is the p-norm distance from ``torch.cdist``. The per-anchor
    losses are averaged over the anchors that have at least one positive and at
    least one negative sample.
    """

    def __init__(self, margin: float = 1.0, p: float = 2, *args, **kwargs):
        super(TripletMarginSP, self).__init__()
        # Kept for interface parity with TripletMargin; compute() does not call it.
        self.loss_fn = torch.nn.TripletMarginLoss(margin=margin, p=p, reduction='none')
        self.margin = margin
        # Norm degree used for the anchor-sample distances (was hard-coded to 2).
        self.p = p

    def compute(self, anchor, sample, pos_mask, neg_mask=None, *args, **kwargs):
        """Return the scalar mean-distance triplet loss.

        Args:
            anchor: anchor embeddings, ``[num_anchors, D]``.
            sample: candidate embeddings, ``[num_samples, D]``.
            pos_mask: ``[num_anchors, num_samples]``; non-zero where the sample
                is a positive for the anchor. Bool or numeric; assumed 0/1-valued
                (TODO confirm with callers).
            neg_mask: optional mask of the same shape marking negatives. Defaults
                to ``1 - pos_mask``.

        Returns:
            A 0-dim tensor. Anchors without any positive or any negative are
            excluded from the average; if no anchor qualifies the result is 0.
        """
        dist = torch.cdist(anchor, sample, p=self.p)  # [num_anchors, num_samples]

        # Cast masks so bool inputs work with arithmetic (``1. - bool`` raises).
        pos_mask = pos_mask.to(dist.dtype)
        neg_mask = 1. - pos_mask if neg_mask is None else neg_mask.to(dist.dtype)

        num_pos = pos_mask.sum(dim=1)
        num_neg = neg_mask.sum(dim=1)

        # clamp(min=1) only affects anchors that are masked out below; it keeps
        # 0/0 = NaN from leaking into the final sum.
        mean_pos = (pos_mask * dist).sum(dim=1) / num_pos.clamp(min=1)
        mean_neg = (neg_mask * dist).sum(dim=1) / num_neg.clamp(min=1)

        loss = torch.relu(mean_pos - mean_neg + self.margin)

        valid = ((num_pos > 0) & (num_neg > 0)).to(loss.dtype)
        return (loss * valid).sum() / valid.sum().clamp(min=1)


class TripletMargin(Loss):
    """Triplet margin loss averaged over every (anchor, positive, negative) triple.

    For each anchor ``a`` and each pair of samples ``(i, j)`` where ``i`` is a
    positive and ``j`` a negative for ``a``, ``torch.nn.TripletMarginLoss``
    computes ``max(d(a, s_i) - d(a, s_j) + margin, 0)``. The result is the mean
    over all such triples.
    """

    def __init__(self, margin: float = 1.0, p: float = 2, *args, **kwargs):
        super(TripletMargin, self).__init__()
        self.loss_fn = torch.nn.TripletMarginLoss(margin=margin, p=p, reduction='none')
        self.margin = margin

    def compute(self, anchor, sample, pos_mask, neg_mask=None, *args, **kwargs):
        """Return the scalar triplet loss over all valid triples.

        Args:
            anchor: anchor embeddings, ``[num_anchors, D]``.
            sample: candidate embeddings, ``[num_samples, D]``.
            pos_mask: ``[num_anchors, num_samples]``; non-zero where the sample
                is a positive for the anchor. Bool or numeric.
            neg_mask: optional mask of the same shape marking negatives. Defaults
                to ``1 - pos_mask``.

        Returns:
            A 0-dim tensor: the triplet losses summed over masked triples,
            divided by the number of triples (0 if there are none).
        """
        num_anchors = anchor.size(0)
        num_samples = sample.size(0)
        grid = (num_anchors, num_samples, num_samples)

        # Key idea:
        #  (1) Build every (anchor, s_i, s_j) combination: num_anchors * num_samples
        #      * num_samples triples, indexed [a, i, j].
        #  (2) TripletMarginLoss expects (B, D) inputs, so flatten the first three
        #      dims of each [N, M, M, D] view into one batch dim.
        #  (3) Compute the loss for all triples, then keep only those where i is a
        #      positive and j a negative for a.
        # NOTE: this materialises N * M * M * D elements per input.
        anchor_all = anchor[:, None, None, :].expand(*grid, -1)  # [N, M, M, D], varies with a
        pos_all = sample[None, :, None, :].expand(*grid, -1)     # [N, M, M, D], varies with i
        neg_all = sample[None, None, :, :].expand(*grid, -1)     # [N, M, M, D], varies with j

        loss = self.loss_fn(
            anchor_all.flatten(end_dim=2),  # [N * M * M, D]
            pos_all.flatten(end_dim=2),
            neg_all.flatten(end_dim=2),
        )
        loss = loss.view(*grid)  # [N, M, M]

        # Cast masks so bool inputs work with arithmetic (``1. - bool`` raises).
        pos_mask = pos_mask.to(loss.dtype)
        neg_mask = 1. - pos_mask if neg_mask is None else neg_mask.to(loss.dtype)

        # pair_mask[a, i, j] = pos_mask[a, i] * neg_mask[a, j]
        pair_mask = pos_mask[:, :, None] * neg_mask[:, None, :]  # [N, M, M]
        num_pairs = pair_mask.sum()

        # clamp(min=1): with no valid triples the numerator is 0, so return 0
        # instead of 0/0 = NaN.
        return (loss * pair_mask).sum() / num_pairs.clamp(min=1)
