import torch
import torch.nn.functional as F

from mmhug.registry import HF_MODELS


@HF_MODELS.register_module()
class InfoNCE:
    """
    This class implements the symmetric InfoNCE loss function.

    Attributes:
    - t: a temperature parameter; the cosine-similarity logits are divided by it before the softmax.
    - threshold_selfsim: similarity threshold used to mask overly similar off-diagonal pairs
      when a same-modal distance matrix is supplied.

    Methods:
    - __call__: computes the InfoNCE loss given the two modal features and optional
      same-modal distance matrices.
    """

    def __init__(self, t: float = 0.1, threshold_selfsim: float = 0.75):
        """
        Initializes the InfoNCE object with a given temperature parameter.

        Inputs:
        - t: a temperature parameter for the softmax function in the loss calculation.
        - threshold_selfsim: a threshold for self-similarity in the distance matrix. Off-diagonal
          entries strictly between this threshold and 1.0 mark pairs as too similar; their logits
          are set to -inf so they are not used as negatives in the loss calculation.
        """
        self.t = t
        self.threshold_selfsim = threshold_selfsim

    def _mask_selfsim(self, logits: torch.Tensor, distance_mat: torch.Tensor) -> torch.Tensor:
        """
        Return a copy of ``logits`` with overly self-similar off-diagonal pairs set to -inf.

        ``distance_mat`` is compared as a similarity (larger means more similar). It is
        detached and moved to the logits' device. The diagonal (the positive pairs) is
        never masked, so every row keeps at least one finite logit.
        """
        sim = distance_mat.detach().to(logits.device)
        # Entries exactly equal to 1.0 are deliberately left unmasked.
        too_similar = torch.logical_and(sim > self.threshold_selfsim, sim < 1.0)
        too_similar.fill_diagonal_(False)
        # masked_fill keeps the logits' dtype/device (no promotion to the distance dtype).
        return logits.masked_fill(too_similar, float("-inf"))

    def __call__(self, modal_a, modal_b, distance_mat_a=None, distance_mat_b=None):
        """
        Computes the symmetric InfoNCE loss given the two modal features and optional distance matrices.

        Inputs:
        - modal_a: Feature from modal_a encoder. Shape in [N, d]
        - modal_b: Feature from modal_b encoder. Shape in [N, d]
        - distance_mat_a: optional similarity matrix between modal_a samples. Shape in [N, N]
        - distance_mat_b: optional similarity matrix between modal_b samples. Shape in [N, N]

        Outputs:
        - loss: scalar tensor, the mean of the a->b and b->a cross-entropy losses, where the
          matching pair (i, i) is the positive for row i.

        Raises:
        - AssertionError: if the batch sizes or the feature dimensions of modal_a and modal_b differ.
        """
        N, d = modal_a.shape[0], modal_a.shape[1]
        assert (
            modal_b.shape[1] == d
        ), f"modal_a and modal_b should have the same feature dimension, but got {modal_a.shape[1]} and {modal_b.shape[1]}"
        assert (
            modal_b.shape[0] == N
        ), f"modal_a and modal_b should have the same batch size, but got {N} and {modal_b.shape[0]}"

        # Cosine similarity via L2-normalized features, scaled by the temperature: (N, N)
        normalized_a = F.normalize(modal_a, dim=1)
        normalized_b = F.normalize(modal_b, dim=1)
        logits = torch.mm(normalized_a, normalized_b.T) / self.t

        # Drop overly similar pairs from the negatives (applied to the shared logit matrix,
        # so the mask also affects the transposed b->a direction).
        if distance_mat_a is not None:
            logits = self._mask_selfsim(logits, distance_mat_a)
        if distance_mat_b is not None:
            logits = self._mask_selfsim(logits, distance_mat_b)

        # Row i's positive is column i. Use logits' device: it exists on every code path.
        labels = torch.arange(N, device=logits.device)
        loss_a = F.cross_entropy(logits, labels)
        loss_b = F.cross_entropy(logits.T, labels)

        return (loss_a + loss_b) / 2

    def __repr__(self):
        return "InfoNCE()"


@HF_MODELS.register_module()
class SyncInfoNCE:
    """
    Multi-positive InfoNCE that treats +/-k temporal neighbors as positives.
    Optionally unions positives with a provided distance matrix (same-modal similarity).

    For each anchor row, every column flagged as positive contributes to the numerator and
    all columns contribute to the denominator; the per-row losses are averaged and the two
    directions (a->b and b->a) are averaged.

    Args:
        t: temperature (float)
        k: temporal radius (int) -- +/- k frames are positives
        threshold_selfsim: if distance matrix provided, entries > threshold are also positives
        large_neg: value used to mask out non-positives when computing numerator (large negative).
            A finite value below the logits dtype's lowest finite number is clamped to that
            number, so the default -1e9 also works for float16 logits.
    """

    def __init__(
        self,
        t: float = 0.1,
        k: int = 1,
        threshold_selfsim: float = 0.75,
        large_neg: float = -1e9,
    ):
        self.t = float(t)
        self.k = int(k)
        self.threshold_selfsim = float(threshold_selfsim)
        self.large_neg = float(large_neg)

    def _positive_masks(self, N, device, distance_mat_a, distance_mat_b):
        """Build the (N, N) bool positive masks for the a->b and b->a directions."""
        idx = torch.arange(N, device=device)
        pos_mask = (idx.unsqueeze(1) - idx.unsqueeze(0)).abs() <= self.k
        # The matching pair is always positive; the unions below can only add positives,
        # so setting the diagonal here is equivalent to setting it at the end.
        pos_mask.fill_diagonal_(True)

        # a->b: b_j is also positive for anchor a_i when b_j is similar to b_i.
        if distance_mat_b is not None:
            assert distance_mat_b.shape == (N, N)
            pos_mask = pos_mask | (distance_mat_b.to(device) > self.threshold_selfsim)

        # b->a: transpose of the a->b mask, optionally unioned with modal_a similarity.
        pos_mask_T = pos_mask.t()
        if distance_mat_a is not None:
            assert distance_mat_a.shape == (N, N)
            pos_mask_T = pos_mask_T | (
                distance_mat_a.to(device) > self.threshold_selfsim
            )
        return pos_mask, pos_mask_T

    def _directional_loss(self, logits, pos_mask):
        """Per-row loss: logsumexp over all columns minus logsumexp over positive columns."""
        # masked_fill raises when the fill value overflows the dtype (e.g. -1e9 in float16),
        # so clamp finite values to the lowest representable one; -inf is kept as-is.
        lowest = torch.finfo(logits.dtype).min
        fill = lowest if float("-inf") < self.large_neg < lowest else self.large_neg
        denom = torch.logsumexp(logits, dim=1)  # (N,)
        numer = torch.logsumexp(logits.masked_fill(~pos_mask, fill), dim=1)  # (N,)
        return denom - numer

    def __call__(
        self,
        modal_a: torch.Tensor,
        modal_b: torch.Tensor,
        distance_mat_a: torch.Tensor = None,
        distance_mat_b: torch.Tensor = None,
    ):
        """
        modal_a: (N, d)
        modal_b: (N, d)
        distance_mat_a: optional (N, N) matrix (same-modal similarity for modal_a)
        distance_mat_b: optional (N, N) matrix (same-modal similarity for modal_b)

        returns:
            loss (scalar): averaged bidirectional multi-positive InfoNCE loss

        raises:
            AssertionError: on rank, batch-size, feature-dim or distance-matrix shape mismatch
        """
        assert modal_a.dim() == 2 and modal_b.dim() == 2
        N, da = modal_a.shape
        Nb, db = modal_b.shape
        assert N == Nb, "modal_a and modal_b must have same batch size N"
        assert da == db, "feature dims must match"

        device = modal_a.device

        # Cosine-similarity logits scaled by the temperature: (N, N)
        a = F.normalize(modal_a, dim=1)
        b = F.normalize(modal_b, dim=1)
        logits = torch.mm(a, b.t()) / self.t

        pos_mask, pos_mask_T = self._positive_masks(
            N, device, distance_mat_a, distance_mat_b
        )

        loss_a = self._directional_loss(logits, pos_mask)
        loss_b = self._directional_loss(logits.t(), pos_mask_T)

        return 0.5 * (loss_a.mean() + loss_b.mean())
