import torch
from src.losses.base_mean_field_loss import BaseMeanFieldLoss


class ClassWiseMultiSimilarityLoss(torch.nn.Module):
    """Multi-similarity loss aggregated at the class level.

    Cosine similarities between L2-normalised embeddings are mapped to
    exponential positive/negative terms, averaged over every (class, class)
    pair of the batch, and combined as ``log(1 + ...)`` terms.

    Args:
        alpha: Scale of the positive (same-class) term.
        beta: Scale of the negative (different-class) term.
        base: Offset subtracted from every cosine similarity.
    """

    def __init__(self, alpha, beta, base):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.base = base

    def get_target_mask(self, embeddings, labels):
        """Return a one-hot mask of shape ``(batch_size, num_classes_in_batch)``.

        Column ``j`` corresponds to the ``j``-th entry of ``torch.unique(labels)``.
        The mask is placed on ``embeddings``' device AND given ``embeddings``'
        dtype, so it can be matrix-multiplied with tensors derived from
        ``embeddings`` (``torch.matmul`` does not promote mismatched dtypes).
        """
        batch_size = labels.size(0)
        label_set = torch.unique(labels)
        mask = (labels.reshape(batch_size, 1) == label_set) * 1.0
        return mask.to(device=embeddings.device, dtype=embeddings.dtype)

    def scale_embeddings(self, embeddings, eps=1e-12):
        """L2-normalise each row of ``embeddings`` (``eps`` avoids division by zero)."""
        return torch.nn.functional.normalize(embeddings, dim=1, eps=eps)

    def forward(
        self,
        embeddings,
        labels,
        indices_tuple=None,  # necessary for compatibility with pytorch-metric-learning
    ):
        """Compute the class-wise multi-similarity loss.

        Args:
            embeddings: Tensor of shape ``(batch_size, embedding_size)``.
            labels: Tensor of shape ``(batch_size,)`` holding class labels.
            indices_tuple: Ignored; accepted for pytorch-metric-learning
                compatibility.

        Returns:
            Scalar loss tensor in the dtype of ``embeddings``. An empty batch
            yields ``0`` (still attached to the autograd graph) instead of NaN.
        """
        if labels.numel() == 0:
            # No classes in the batch: the normaliser below would be 0 and the
            # loss 0/0 = NaN, which would poison training.
            return embeddings.sum() * 0.0

        embeddings = self.scale_embeddings(
            embeddings
        )  # shape = (batch_size, embedding_size)
        target_mask = self.get_target_mask(
            embeddings=embeddings, labels=labels
        )  # shape = (batch_size, num_classes_in_batch)
        # One column per distinct label, i.e. the same value as
        # torch.unique(labels).shape[0] without recomputing it.
        num_classes_in_batch = target_mask.shape[1]
        mask = torch.eye(
            num_classes_in_batch, dtype=embeddings.dtype, device=embeddings.device
        )  # shape = (num_classes_in_batch, num_classes_in_batch)

        # Divide each column by its class size so the products below average
        # (rather than sum) over the samples of each class. Every column has
        # at least one sample; the clamp only guards against division by zero.
        target_mask_normalized = target_mask / target_mask.sum(dim=0).clamp(min=1.0)

        cos = embeddings @ embeddings.T  # shape = (batch_size, batch_size)
        pos_logit = torch.exp(-self.alpha * (cos - self.base))
        neg_logit = torch.exp(self.beta * (cos - self.base))

        # Entry (i, j): mean of the logit over sample pairs from classes i and j.
        pos = target_mask_normalized.T @ pos_logit @ target_mask_normalized
        neg = target_mask_normalized.T @ neg_logit @ target_mask_normalized

        # Positive term uses the diagonal only (same-class pairs).
        # NOTE(review): halved here but not in MeanFieldClassWiseMultiSimilarityLoss
        # -- confirm this asymmetry is intentional.
        pos_term = torch.sum(torch.log(1 + pos * mask / 2)) / (
            self.alpha * num_classes_in_batch
        )
        # Negative term uses the off-diagonal. ``neg`` is symmetric (``cos`` is),
        # so each unordered class pair appears twice; the factor 2 undoes that.
        neg_term = torch.sum(torch.log(1 + neg * (1 - mask))) / (
            2 * self.beta * num_classes_in_batch
        )
        return pos_term + neg_term


class MeanFieldClassWiseMultiSimilarityLoss(BaseMeanFieldLoss):
    """Class-wise multi-similarity loss computed against per-class vectors.

    Instead of comparing samples with each other, each sample is compared with
    class vectors provided by ``BaseMeanFieldLoss`` (see
    ``get_distance_from_class_vec``). ``get_reg`` adds a penalty on the
    class-distance matrix from ``get_class_distance_matrix``.

    Args:
        num_classes: Total number of classes; forwarded to the base class.
        embedding_size: Embedding dimension; forwarded to the base class.
        alpha: Scale of the positive (own-class) term.
        beta: Scale of the negative term and softplus sharpness in the
            regulariser; ``None`` falls back to ``alpha``.
        base: Default for ``pos_margin`` / ``neg_margin`` when not given.
        mf_reg: Weight of the regulariser.
        mf_power: Exponent applied to the regulariser's softplus penalties.
        pos_margin: Margin of the positive term (defaults to ``base``).
        neg_margin: Margin of the negative term (defaults to ``base``).
        **kwargs: Forwarded to ``BaseMeanFieldLoss``.
    """

    def __init__(
        self,
        num_classes,
        embedding_size,
        alpha,
        beta,
        base,
        mf_reg,
        mf_power,
        pos_margin=None,
        neg_margin=None,
        **kwargs,
    ):
        super().__init__(
            num_classes=num_classes,
            embedding_size=embedding_size,
            **kwargs,
        )
        self.alpha = alpha
        self.beta = beta if beta is not None else alpha
        self.pos_margin = base if pos_margin is None else pos_margin
        self.neg_margin = base if neg_margin is None else neg_margin
        self.mf_reg = mf_reg
        self.mf_power = mf_power

    def get_loss(self, embeddings, labels):
        """Return the sample-to-class-vector multi-similarity loss.

        ``pos`` / ``neg`` are ``(num_classes, num_classes)`` matrices whose
        entry ``(i, j)`` averages, over the samples of class ``i``, the
        exponential term w.r.t. class vector ``j`` (assuming ``get_target_mask``
        returns a one-hot mask -- defined in the base class, TODO confirm).
        Only the diagonal of ``pos`` and the off-diagonal of ``neg`` contribute.
        """
        embeddings = self.scale_embeddings(embeddings)
        num_classes_in_batch = torch.unique(labels).shape[0]
        distance_from_class_vec = self.get_distance_from_class_vec(
            embeddings=embeddings
        )  # 1 - emb @ W >= 0 on "S" # shape = (batch_size, num_classes)

        target_mask = self.get_target_mask(
            embeddings=embeddings, labels=labels
        )  # shape = (batch_size, num_classes)
        mask = torch.eye(
            n=self.num_classes, device=embeddings.device
        )  # shape = (num_classes, num_classes)

        # Divide each column by its class size so the products below average
        # over samples; the floor of 1 avoids dividing by zero for classes
        # absent from the batch.
        target_mask_normalized = target_mask / torch.maximum(
            torch.sum(target_mask, dim=0), torch.ones(1, device=target_mask.device)
        )

        # With distance = 1 - similarity these are exp(-alpha * (sim - pos_margin))
        # and exp(beta * (sim - neg_margin)).
        pos_logit = torch.exp(
            self.alpha * (distance_from_class_vec + self.pos_margin - 1)
        )
        neg_logit = torch.exp(
            -self.beta * (distance_from_class_vec + self.neg_margin - 1)
        )
        pos = target_mask_normalized.T @ pos_logit
        neg = target_mask_normalized.T @ neg_logit
        # Symmetrise: (i, j) combines "class-i samples vs vector j" with
        # "class-j samples vs vector i"; each unordered pair then appears twice,
        # which the factor 2 in the denominator below compensates.
        neg = neg + neg.T

        loss = torch.sum(torch.log(1 + pos * mask)) / (
            self.alpha * num_classes_in_batch
        ) + torch.sum(torch.log(1 + neg * (1 - mask))) / (
            2 * self.beta * num_classes_in_batch
        )
        return loss

    def get_reg(self):
        """Return the penalty on the class-distance matrix.

        Applies ``softplus_beta(1 - neg_margin - d)`` to each off-diagonal entry
        ``d`` of ``get_class_distance_matrix()``, raises it to ``mf_power``, sums
        per row, averages over rows and scales by ``mf_reg``.
        """
        class_matrix = self.get_class_distance_matrix()
        mask = torch.eye(self.num_classes, device=class_matrix.device)
        # Smooth hinge with sharpness beta (same as torch.nn.Softplus(beta)).
        penalty = torch.nn.functional.softplus(
            1 - self.neg_margin - class_matrix, beta=self.beta
        )
        # Drop the diagonal (class vs. itself); assumes mf_power > 0 so the
        # masked zeros stay zero after the power.
        off_diagonal = penalty * (1 - mask)
        reg = self.mf_reg * torch.mean(
            torch.sum(torch.pow(off_diagonal, self.mf_power), dim=1)
        )
        return reg

    def forward(
        self,
        embeddings,
        labels,
        indices_tuple=None,  # necessary for compatibility with pytorch-metric-learning
    ):
        """Return ``get_loss(embeddings, labels) + get_reg()``.

        ``cast_types`` (from the base class) is called first with the dtype and
        device of ``embeddings``, presumably to move the class vectors there.
        ``indices_tuple`` is ignored.
        """
        dtype, device = embeddings.dtype, embeddings.device
        self.cast_types(dtype, device)

        loss = self.get_loss(embeddings=embeddings, labels=labels)
        reg = self.get_reg()

        total_loss = loss + reg
        return total_loss
