import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class DummyLoss:
    """Plain mean-squared-error loss that ignores ``gamma``.

    Accepts the same ``(v1, v2, gamma)`` call arguments as the other loss
    classes in this module, so it can be swapped in for them, but the
    ``gamma`` weight has no effect on the result.
    """

    def __init__(self):
        # "mean" is nn.MSELoss's default reduction; spelled out for clarity.
        self.mse = nn.MSELoss(reduction="mean")

    def __call__(self, v1, v2, gamma):
        """Return ``mean((v1 - v2) ** 2)`` over all elements; ``gamma`` is unused."""
        criterion = self.mse
        return criterion(v1, v2)


class MSEContrastiveLoss:
    """Margin-based contrastive loss on the squared L2 distance of two tensors.

    For each row ``i``, let ``d_i = sum_j (v1[i, j] - v2[i, j]) ** 2``. The
    per-row loss is::

        loss_i = (1 - gamma_i) * d_i + gamma_i * max(0, margin - d_i) ** 2

    With ``gamma = 0`` the loss pulls ``v1`` toward ``v2``. With ``gamma = 1``
    it pushes them apart until ``d_i`` reaches ``margin``. The returned value
    is the sum of the per-row losses divided by ``v1.shape[0]``.

    Args:
        margin: Threshold on the *squared* distance ``d_i`` beyond which the
            push-apart term stops contributing.
    """

    def __init__(self, margin=1000.0):
        self.margin = margin
        self.mse = nn.MSELoss(reduction="none")

    def __call__(self, v1, v2, gamma):
        """Compute the batch-averaged contrastive loss.

        Args:
            v1: Tensor with the batch on dim 0; squared errors are summed over
                dim 1. Presumably shaped (batch, features) — TODO confirm.
            v2: Tensor with the same shape as ``v1``.
            gamma: Scalar or tensor broadcastable against the per-row
                distances, e.g. shape (batch,). Blends the pull term
                (weight ``1 - gamma``) and the push term (weight ``gamma``).

        Returns:
            0-dim tensor: sum of per-row losses divided by ``v1.shape[0]``.
        """
        # Despite being an "MSE", this is the per-row squared L2 distance.
        sq_dist = torch.sum(self.mse(v1, v2), dim=1)
        # clamp keeps dtype/device of sq_dist and avoids allocating a CPU
        # scalar tensor every call (as torch.maximum(torch.tensor(0.0), ...) did).
        hinge = torch.clamp(self.margin - sq_dist, min=0.0)
        loss = ((1 - gamma) * sq_dist) + (gamma * hinge**2)
        return loss.sum() / v1.shape[0]


class KLDivContrastiveLoss:
    """Contrastive loss blending two KL-divergence terms over class logits.

    Let ``p1 = softmax(v1)`` and ``p2 = softmax(v2)`` over dim 1, with ``C``
    classes. For each row ``i``::

        retain_i = KL(p2 || p1)        # v2 is detached: fixed reference
        forget_i = KL(target || p1)
        loss_i   = (1 - gamma_i) * retain_i + gamma_i * forget_i

    ``target`` puts ``epsilon`` on ``v2``'s arg-max class and
    ``1 / (C - 1) - epsilon`` on each of the other ``C - 1`` classes. It
    therefore sums to ``1 - (C - 2) * epsilon``, not exactly 1.

    The returned value is the sum of ``loss_i`` divided by ``v1.shape[0]``.

    Args:
        epsilon: Probability mass assigned to ``v2``'s predicted class in the
            forget target, and subtracted from every other class. Must be
            below ``1 / (C - 1)``.
    """

    def __init__(self, epsilon=1e-10):
        self.epsilon = epsilon
        self.kl_loss = nn.KLDivLoss(reduction="none", log_target=True)

    def __call__(self, v1, v2, gamma):
        """Compute the batch-averaged retain/forget KL loss.

        Args:
            v1: Logits with the batch on dim 0 and classes on dim 1.
                Presumably (batch, C) — TODO confirm.
            v2: Reference logits with the same shape as ``v1``. No gradient
                flows through ``v2``.
            gamma: Scalar or tensor broadcastable against the per-row losses,
                e.g. shape (batch,).

        Returns:
            0-dim tensor: sum of per-row losses divided by ``v1.shape[0]``.

        Raises:
            ZeroDivisionError: if there is only one class (``C == 1``).
        """
        epsilon = self.epsilon
        lv1 = F.log_softmax(v1, dim=1)
        # The reference distribution is a fixed target.
        lv2 = F.log_softmax(v2, dim=1).detach()

        # The forget mass is shared by every class except the predicted one.
        other_classes = lv1.shape[1] - 1
        class_prob = (1 / other_classes) - epsilon
        pred_class = torch.argmax(v2, dim=1)

        # Build the forget target directly in log space. Materialising
        # `epsilon` as a probability underflows to 0 in float16, and a
        # log-target of -inf makes KLDivLoss yield NaN (exp(-inf) * -inf).
        # scatter_ also keeps the index on the same device as the data.
        log_target = torch.full_like(v1, math.log(class_prob))
        log_target.scatter_(1, pred_class.unsqueeze(1), math.log(epsilon))

        retain_loss = self.kl_loss(lv1, lv2).sum(dim=1)
        forget_loss = self.kl_loss(lv1, log_target).sum(dim=1)
        loss = ((1 - gamma) * retain_loss) + (gamma * forget_loss)
        return loss.sum() / v1.shape[0]
