import torch
import torch.nn.functional as F
from .losses import Loss


class BootstrapLatent(Loss):
    """Negative-free latent alignment loss between anchor and sample embeddings.

    Both inputs are L2-normalised, so pairwise dot products are cosine
    similarities. Only the pairs selected by ``pos_mask`` contribute; no
    negatives are used (presumably for BYOL/BGRL-style bootstrap training),
    so ``neg_mask`` is accepted for interface compatibility and ignored.
    """

    def __init__(self):
        super().__init__()

    def compute(self, anchor, sample, pos_mask, neg_mask=None, *args, **kwargs) -> torch.FloatTensor:
        """Return the negated mean positive cosine similarity.

        Args:
            anchor: embeddings, assumed ``(N, D)`` — TODO confirm with callers.
            sample: embeddings; must be at most 2-D because ``.t()`` is used,
                assumed ``(M, D)``.
            pos_mask: weights broadcastable to the ``(N, M)`` similarity matrix;
                non-zero entries mark anchor/sample pairs that should agree
                (e.g. an identity matrix pairs row ``i`` with row ``i``).
                NOTE(review): passing ``None`` here raises a TypeError.
            neg_mask: unused; kept for signature compatibility with other losses.

        Returns:
            Scalar tensor to be minimised. Lower values mean higher cosine
            similarity between positive pairs.
        """
        anchor = F.normalize(anchor, dim=-1, p=2)
        sample = F.normalize(sample, dim=-1, p=2)

        # Cosine similarity between every anchor row and every sample row.
        similarity = anchor @ sample.t()
        # Per-anchor total similarity over its positive pairs.
        pos_similarity = (similarity * pos_mask).sum(dim=-1)
        # Negate: a Loss is minimised, whereas the bootstrap objective is to
        # maximise agreement of positives. Returning +similarity would drive
        # positive pairs apart during training.
        return -pos_similarity.mean()
