import torch
from torch import nn
import torch.nn.functional as F


class ExtendedEmbeddingCriterion(nn.Module):
    """Loss measuring how far a batch of embeddings is from a fixed target.

    The distance between the target embedding and every row of the input is
    computed with either cosine distance (``1 - cosine_similarity``, the
    default) or Euclidean distance (``torch.cdist``), and the per-row
    distances are then collapsed to a scalar with the chosen reduction.

    Args:
        target_embedding: The reference embedding. It is indexed as
            ``target_embedding[None, :]`` before being compared to the
            input, so it is presumably a 1-D tensor of shape ``(dim,)``
            -- confirm against callers.
        euclidean: When true, use Euclidean distance instead of cosine
            distance.
        reduction: How to collapse the per-row distances: ``"mean"``,
            ``"sum"`` or ``"min"``. The value is checked when ``forward``
            runs, not at construction time.
    """

    def __init__(self, target_embedding, euclidean=False, reduction="mean"):
        super().__init__()
        # Stored as a plain attribute, exactly as received.
        self.target = target_embedding
        self.euclidean = euclidean
        self.reduction = reduction

    def forward(self, embeds):
        """Return the reduced distance between the target and ``embeds``.

        Args:
            embeds: A 2-D tensor with one embedding per row.

        Returns:
            A 0-dim tensor holding the reduced distance.

        Raises:
            ValueError: If ``embeds`` is not 2-D, or if ``self.reduction``
                is not one of ``"mean"``, ``"sum"`` or ``"min"``.
        """
        # Explicit check instead of `assert`, which is removed under `-O`.
        if len(embeds.shape) != 2:
            raise ValueError(
                f"Expected a 2-D tensor of embeddings, got shape "
                f"{tuple(embeds.shape)}"
            )

        # Add a leading axis so the target lines up with the rows of embeds.
        target = self.target[None, :]
        if self.euclidean:
            distances = torch.cdist(target, embeds)
        else:
            distances = 1 - F.cosine_similarity(target, embeds)

        # One shared reduction step, so `reduction` is honoured for both
        # metrics rather than only for the Euclidean one.
        if self.reduction == "mean":
            return distances.mean()
        if self.reduction == "sum":
            return distances.sum()
        if self.reduction == "min":
            return distances.min()
        raise ValueError(f"Unknown reduction {self.reduction}")