"""
Hard-Negative NCE loss for contrastive learning.
https://arxiv.org/pdf/2301.02280.pdf
"""
# third party
import torch
import torch.nn as nn
import torch.nn.functional as F


class CrossEntropyLoss(nn.Module):
    """
    Symmetric Info-NCE (cross-entropy) loss for contrastive learning.

    Row ``i`` of the target features is the positive for row ``i`` of the
    query features; every other row in the batch acts as a negative. The loss
    is the average of the target->query and query->target cross-entropies.
    """

    def __init__(self):
        super(CrossEntropyLoss, self).__init__()

    def forward(self, tar_img_feat: torch.Tensor, query_feat: torch.Tensor, temp):
        """
        Args:
            tar_img_feat: target features, (batch_size, dim)
            query_feat: query features, (batch_size, dim)
            temp: softmax temperature; the similarity logits are divided by it

        Returns:
            Scalar loss tensor (mean of the two directional cross-entropies).
        """
        # (batch_size, batch_size) target->query logits. The query->target
        # logits are exactly its transpose, so a single matmul serves both.
        sim_t2q = tar_img_feat @ query_feat.T / temp
        labels = torch.arange(sim_t2q.size(0), device=tar_img_feat.device)

        loss_t2q = F.cross_entropy(sim_t2q, labels)
        loss_q2t = F.cross_entropy(sim_t2q.T, labels)

        return (loss_t2q + loss_q2t) / 2


class HardNegativeNCE(nn.Module):
    """
    Hard Negative NCE (HN-NCE) loss for contrastive learning.

    Each negative in the softmax denominator is re-weighted by
    ``exp(beta * sim)`` normalised over the other negatives of the same anchor
    (scaled so the weights average to 1), so harder negatives count more. The
    positive term in the denominator is scaled by ``alpha``.
    """

    def __init__(
        self, temperature: float = 0.07, alpha: float = 1.0, beta: float = 0.0
    ):
        """
        Args:
            temperature: temperature for the softmax
            alpha: rescaling factor for the positive term in the denominator
            beta: concentration parameter; larger values up-weight negatives
                that are more similar to the anchor

        Note:
            alpha = 1 and beta = 0 corresponds to the original Info-NCE loss
            (summed over the two directions rather than averaged).
        """
        super(HardNegativeNCE, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.temperature = temperature

    def forward(
        self,
        video_embds: torch.Tensor,
        text_embds: torch.Tensor,
        hard_negatives: torch.Tensor = None,  # type: ignore
    ):
        """
        Args:
            video_embds: (batch_size, embd_dim)
            text_embds: (batch_size, embd_dim)
            hard_negatives: optional tensor added element-wise to the
                temperature-scaled similarity matrix (presumably
                (batch_size, batch_size) -- confirm with callers).

        Returns:
            Scalar loss: mean video->text term plus mean text->video term.
        """
        batch_size = video_embds.size(0)
        # sim_matrix[i, j] = <v_i, t_j> / temperature, computed in float32
        sim_matrix = (video_embds @ text_embds.T) / self.temperature
        sim_matrix = sim_matrix.float()
        if hard_negatives is not None:
            sim_matrix = sim_matrix + hard_negatives

        nominator = torch.diagonal(sim_matrix)

        diag_mask = torch.eye(
            batch_size, dtype=torch.bool, device=sim_matrix.device
        )
        exp_beta_sim = torch.exp(self.beta * sim_matrix)
        exp_beta_neg = exp_beta_sim.masked_fill(diag_mask, 0.0)

        # v2t normalises over the negatives of each video (row i); t2v over the
        # negatives of each text (column j). keepdim=True is essential: a plain
        # (batch_size,) row-sum would broadcast along columns and divide entry
        # (i, j) by the sum of row j instead of row i.
        w_v2t = (
            (batch_size - 1)
            * exp_beta_sim
            / exp_beta_neg.sum(dim=1, keepdim=True)
        )
        w_t2v = (
            (batch_size - 1)
            * exp_beta_sim
            / exp_beta_neg.sum(dim=0, keepdim=True)
        )
        # The positive pair gets weight alpha in both directions.
        w_v2t = w_v2t.masked_fill(diag_mask, self.alpha)
        w_t2v = w_t2v.masked_fill(diag_mask, self.alpha)

        exp_sim = torch.exp(sim_matrix)
        denominator_v2t = torch.log((exp_sim * w_v2t).sum(dim=1))
        denominator_t2v = torch.log((exp_sim * w_t2v).sum(dim=0))

        hn_nce_loss = (denominator_v2t - nominator).mean() + (
            denominator_t2v - nominator
        ).mean()
        return hn_nce_loss


def info_nce_loss(
    tar_img_feat: torch.Tensor,
    query_feat: torch.Tensor,
    temp,
    text_selfsim=None,
    threashold_self_sim=0.6,
):
    """
    Symmetric Info-NCE loss with optional masking of likely false negatives.

    Args:
        tar_img_feat: target features, (batch_size, dim); row i is the
            positive for row i of ``query_feat``.
        query_feat: query features, (batch_size, dim).
        temp: softmax temperature; the similarity logits are divided by it.
        text_selfsim: optional (batch_size, batch_size) similarity matrix
            between queries. Off-diagonal pairs whose value exceeds
            ``threashold_self_sim`` are removed from the negatives (their
            logits are set to -inf). Positive (diagonal) pairs are never masked.
        threashold_self_sim: masking threshold. A falsy value (``None``/``0``)
            disables masking. (Name kept as-is for backward compatibility.)

    Returns:
        Scalar loss: mean of the query->target and target->query
        cross-entropies.
    """
    sim_q2t = query_feat @ tar_img_feat.T / temp
    bs = sim_q2t.size(0)
    device = sim_q2t.device

    if text_selfsim is not None and threashold_self_sim:
        # Restrict masking to off-diagonal entries so the positive logit always
        # remains a valid target, whatever the sign of the threshold.
        off_diag = ~torch.eye(bs, dtype=torch.bool, device=device)
        above = (text_selfsim > threashold_self_sim).to(device)
        sim_q2t = sim_q2t.masked_fill(above & off_diag, -torch.inf)

    labels = torch.arange(bs, device=device)

    total_loss = (
        F.cross_entropy(sim_q2t, labels) + F.cross_entropy(sim_q2t.T, labels)
    ) / 2

    return total_loss


if __name__ == "__main__":
    # sanity check that the loss is working
    # looking whether the loss is equal to the Info-NCE loss
    # when alpha = 1 and beta = 0
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    for _ in range(10):
        in_video_embd = torch.randn(1024, 512).to(device)
        in_text_embd = torch.randn(1024, 512).to(device)

        # normalize
        in_video_embd = F.normalize(in_video_embd, dim=-1)
        in_text_embd = F.normalize(in_text_embd, dim=-1)
        loss_fn = HardNegativeNCE(beta=0.0, alpha=1.0)
        loss_fn(in_video_embd, in_text_embd, debug_test=True)
