from functools import partial
import torch
from torch import nn, Tensor
from typing import Iterable, Dict, Optional

import torch.distributed as dist


def cos_sim(a: Tensor, b: Tensor) -> Tensor:
    """
    Return the matrix of cosine similarities between every row of ``a`` and every row of ``b``.

    Inputs that are not tensors are converted with ``torch.tensor``; a 1-D input is
    treated as a single row.

    :return: Matrix with res[i][j] = cos_sim(a[i], b[j])
    """

    def _as_rows(value):
        # Accept array-likes and promote a lone vector to a 1 x d matrix.
        rows = value if isinstance(value, torch.Tensor) else torch.tensor(value)
        return rows.unsqueeze(0) if len(rows.shape) == 1 else rows

    rows_a = _as_rows(a)
    rows_b = _as_rows(b)

    # With unit-length rows, the dot product is the cosine of the angle.
    unit_a = torch.nn.functional.normalize(rows_a, p=2, dim=1)
    unit_b = torch.nn.functional.normalize(rows_b, p=2, dim=1)
    return torch.mm(unit_a, unit_b.transpose(0, 1))


class MultipleNegativesRankingLoss(nn.Module):
    """
    This loss expects as input a batch consisting of sentence pairs (a_1, p_1), (a_2, p_2)..., (a_n, p_n)
    where we assume that (a_i, p_i) are a positive pair and (a_i, p_j) for i!=j a negative pair.

    For each a_i, it uses all other p_j as negative samples, i.e., for a_i, we have 1 positive example (p_i) and
    n-1 negative examples (p_j). It then minimizes the negative log-likelihood for softmax-normalized scores.

    This loss function works great to train embeddings for retrieval setups where you have positive pairs (e.g. (query, relevant_doc))
    as it will sample in each batch n-1 negative docs randomly.

    The performance usually increases with increasing batch sizes.

    For more information, see: https://arxiv.org/pdf/1705.00652.pdf
    (Efficient Natural Language Response Suggestion for Smart Reply, Section 4.4)

    You can also provide one or multiple hard negatives per anchor-positive pair by structuring the data like this:
    (a_1, p_1, n_1), (a_2, p_2, n_2)

    Here, n_1 is a hard negative for (a_1, p_1). The loss will use for the pair (a_i, p_i) all p_j (j!=i) and all n_j as negatives.

    Example::

        from sentence_transformers import SentenceTransformer, losses, InputExample
        from torch.utils.data import DataLoader

        model = SentenceTransformer('distilbert-base-uncased')
        train_examples = [InputExample(texts=['Anchor 1', 'Positive 1']),
            InputExample(texts=['Anchor 2', 'Positive 2'])]
        train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=32)
        train_loss = losses.MultipleNegativesRankingLoss(model=model)
    """

    def __init__(
        self,
        loss_type: str = "cross_batch_negative",
        mask_k_ldiags: int = None,
        mask_k_udiags: int = None,
        pick_k: int = None,
        k_pos_labels: int = None,
        decay_factor: float = 1.0,
        n_gram: int = None,
        keep_k_cross_device_negatives: int = None,
        compute_k_loss: int = None,
        k_random_pos_labels: int = None,
        mask_full_ldiag: bool = False,
        n_hard_negatives: int = 1,
        scale: float = 20.0,
        similarity_fct=cos_sim,
    ):
        """
        :param model: SentenceTransformer model
        :param scale: Output of similarity function is multiplied by scale value
        :param similarity_fct: similarity function between sentence embeddings. By default, cos_sim. Can also be set to dot product (and then set scale to 1)
        """
        super(MultipleNegativesRankingLoss, self).__init__()
        self.loss_type = loss_type
        self.mask_k_ldiags = mask_k_ldiags
        self.mask_k_udiags = mask_k_udiags
        self.n_gram = n_gram
        self.pick_k = pick_k
        self.k_pos_labels = k_pos_labels
        self.decay_factor = decay_factor
        self.keep_k_cross_device_negatives = keep_k_cross_device_negatives
        self.compute_k_loss = compute_k_loss
        self.k_random_pos_labels = k_random_pos_labels
        self.mask_full_ldiag = mask_full_ldiag
        self.target_per_query = n_hard_negatives + 1
        self.scale = scale
        self.similarity_fct = similarity_fct
        self.cross_entropy_loss = nn.CrossEntropyLoss()
        self.rank = dist.get_rank()
        self.log_decay_factor = torch.log(torch.tensor(self.decay_factor, dtype=torch.float32))
        # current device based on rank
        device = torch.device("cuda", self.rank) if torch.cuda.is_available() else torch.device("cpu")
        big_prime = 15485863
        max_int = 2 ** 31 - 1
        self.generator = torch.Generator(device).manual_seed((max_int - big_prime)*self.rank)

    @torch._dynamo.disable(recursive=True)
    def distributed_loss(
        self,
        sentence_features: torch.Tensor,
        t_sizes: torch.Tensor,
        group_rank: int,
        track_memory_finegrained: bool = False,
        memory_stats: Optional[Dict[str, float]] = None,
    ) -> torch.Tensor:
        
        memory_stats = {} if memory_stats is None else memory_stats
        def check_memory(rank=self.rank, curr_device=sentence_features[0][0].device):
            memory_allocated_per_gpu = torch.cuda.memory_allocated(curr_device) / 1024**3
            memory_reserved_per_gpu = torch.cuda.memory_reserved(curr_device) / 1024**3
            max_memory_allocated_per_gpu = torch.cuda.max_memory_allocated(curr_device) / 1024**3
            max_memory_reserved_per_gpu = torch.cuda.max_memory_reserved(curr_device) / 1024**3
            # if rank == 0:
            #     print(f"Memory Allocated (curr): {memory_allocated_per_gpu:.2f} GB")
            #     print(f"Memory Allocated (XXXX-13): {max_memory_allocated_per_gpu:.2f} GB")
            #     print(f"Memory Reserved (curr): {memory_reserved_per_gpu:.2f} GB")
            #     print(f"Memory Reserved (XXXX-13): {max_memory_reserved_per_gpu:.2f} GB")
            return {
                "memory_allocated_per_gpu": memory_allocated_per_gpu,
                "max_memory_allocated_per_gpu": max_memory_allocated_per_gpu,
                "memory_reserved_per_gpu": memory_reserved_per_gpu,
                "max_memory_reserved_per_gpu": max_memory_reserved_per_gpu,
            }

        embeddings_a_bsz_T_d = sentence_features[0]
        embeddings_b_bsz_T_d = sentence_features[1]

        if isinstance(embeddings_a_bsz_T_d, list):  # it'd s ragged tensor (because of the padding removal)
            seq_lens = [
                x.size(0) for x in embeddings_a_bsz_T_d
            ]  # we need to know the seq lens for k positive labels creation)
            embeddings_a_bsz_d = torch.cat(embeddings_a_bsz_T_d, dim=0)  # (T_1+T_2+...+T_n, d)
        else:
            embeddings_a_bsz_d = embeddings_a_bsz_T_d.reshape(-1, embeddings_a_bsz_T_d.size(-1))  # (bsz * T, d)

        if isinstance(embeddings_b_bsz_T_d, list):
            embeddings_b_bsz_d = torch.cat(embeddings_b_bsz_T_d, dim=0)  # (T_1+T_2+...+T_n, d)
        else:
            embeddings_b_bsz_d = embeddings_b_bsz_T_d.reshape(-1, embeddings_b_bsz_T_d.size(-1))  # (bsz * T, d)

        if embeddings_a_bsz_d.size(0) != embeddings_b_bsz_d.size(0):
            cum_sized = torch.cumsum(t_sizes, dim=0)
            if self.target_per_query > 1: # means we have hard negatives (phase 3 data type (query, positive, negative))
                local_suffix_start_idx = cum_sized[group_rank] - t_sizes[group_rank]
                local_suffix_end_idx = (cum_sized[group_rank] - t_sizes[group_rank]) + t_sizes[group_rank]
                labels = torch.tensor(
                    range(
                        local_suffix_start_idx,
                        local_suffix_end_idx,
                        self.target_per_query,
                    ),
                    dtype=torch.long,
                    device=embeddings_a_bsz_d.device,
                )
            else:
                labels = torch.tensor(
                    range(
                        cum_sized[group_rank] - t_sizes[group_rank],
                        (cum_sized[group_rank] - t_sizes[group_rank]) + t_sizes[group_rank],
                    ),
                    dtype=torch.long,
                    device=embeddings_a_bsz_d.device,
                )
        else:
            labels = torch.tensor(
                range(0, embeddings_a_bsz_d.size(0) * self.target_per_query, self.target_per_query), dtype=torch.long, device=embeddings_a_bsz_d.device
            )  # Example a[i] should match with b[i]    [0, 1, 2, 3, ...] or [0, 2, 4, 6, ...] if target_per_query = 2 when have 1 hard negative for each pair

        if self.keep_k_cross_device_negatives != None:
            embeddings_b_bsz_d, labels = self.sample_cross_device_negatives(embeddings_b_bsz_d, labels)

        if self.pick_k:
            embeddings_a_bsz_d, embeddings_b_bsz_d, labels = self.pick_k_pairs(
                embeddings_a_bsz_d, embeddings_b_bsz_d, seq_lens, labels
            )

        scores = self.similarity_fct(embeddings_a_bsz_d, embeddings_b_bsz_d) * self.scale  # [b, b]
        if track_memory_finegrained:
            memory_stats.update({f"after_similarity_fct/{k}": v for k, v in check_memory().items()})

        if self.mask_k_ldiags:
            scores = self.mask_k_lower_diags_fast(scores, seq_lens, diag_positions=labels)

        if self.mask_full_ldiag:
            scores = self.mask_full_lower_diag(scores, seq_lens, diag_positions=labels)

        if self.mask_k_udiags:
            scores = self.mask_k_upper_diags(scores, seq_lens, diag_positions=labels)

        if self.k_pos_labels != None:
            # This will be a 2D tensor with 1s and 0s
            labels_w_k_positives = self.create_k_positives(scores, seq_lens, labels)

            # This will be a 2D tensor with probabilities in each row
            labels = (
                self.decay_label_probs(labels_w_k_positives) if self.k_pos_labels > 0 else labels_w_k_positives
            )  # cos we don't want to do decay when it's only 1 label per sample

        if self.k_random_pos_labels != None:
            # This will be a new 1D label tensor with random offsets (we basically pick a random k right-shifted suffix as the positive pair)
            k = torch.randint(0, self.k_random_pos_labels, (1,), device=labels.device, generator=self.generator).item()
            scores, labels = self.create_new_diag_labels_randomly(scores, labels, seq_lens, k=k)

        if self.n_gram:
            scores, labels, keep_indices = self.mask_n_gram_rows(scores, seq_lens, labels)
            if self.k_pos_labels != None:
                labels_w_k_positives = labels_w_k_positives[keep_indices]

        if self.k_pos_labels != None:
            accuracy = self.top_k_recall(
                scores, labels_w_k_positives
            )  # we use `labels_w_k_positives` for accuracy because this var has 0 0 1 1 0 type data (unlike labels which is probability dist. when k_pos_labels > 0)
        else:
            accuracy = (
                (torch.argmax(scores, dim=1) == labels).float().mean()
            )  # we want to check the retrieval accuracy

        loss = self.cross_entropy_loss(scores, labels)

        if self.compute_k_loss:
            # we've already computed the loss and accuracy 1 time above, so now we'll do k more times
            for i in range(self.compute_k_loss):
                new_diag_labels = self.create_new_diag_labels(labels, seq_lens, i + 1)
                accuracy += (torch.argmax(scores, dim=1) == new_diag_labels).float().mean()
                loss += self.cross_entropy_loss(scores, new_diag_labels)

            accuracy = accuracy / (self.compute_k_loss + 1)
            loss = loss / (self.compute_k_loss + 1)

        if track_memory_finegrained:
            memory_stats.update({f"after_ce_calc/{k}": v for k, v in check_memory().items()})

        return loss, accuracy

    def forward(self, sentence_features: list[Tensor, Tensor]):
        embeddings_a_bsz_T_d = sentence_features[0]
        embeddings_b_bsz_T_d = sentence_features[1]

        if self.loss_type == "cross_batch_negative":
            outputs = self.get_cross_batch_negative_loss(embeddings_a_bsz_T_d, embeddings_b_bsz_T_d)

            return outputs
        else:
            # loss_type == "sequence_negative"
            losses = torch.zeros(
                len(embeddings_a_bsz_T_d), dtype=embeddings_a_bsz_T_d[0].dtype, device=embeddings_a_bsz_T_d[0].device
            )
            accuracies = torch.zeros(
                len(embeddings_a_bsz_T_d), dtype=embeddings_a_bsz_T_d[0].dtype, device=embeddings_a_bsz_T_d[0].device
            )
            for i in range(len(embeddings_a_bsz_T_d)):
                outputs = self.get_cross_batch_negative_loss(embeddings_a_bsz_T_d[i], embeddings_b_bsz_T_d[i])
                losses[i], accuracies[i] = outputs["loss"], outputs["accuracy"]

            return {"loss": losses.mean(), "accuracy": accuracies.mean()}

    def get_cross_batch_negative_loss(self, embeddings_a_bsz_T_d, embeddings_b_bsz_T_d):
        if isinstance(embeddings_a_bsz_T_d, list):  # it'd be a ragged tensor (because of the padding removal)
            seq_lens = [
                x.size(0) for x in embeddings_a_bsz_T_d
            ]  # we need to know the seq lens for k positive labels creation (note seq_lens of embs_a and embs_b are same they're pairs)
            embeddings_a_bsz_d = torch.cat(embeddings_a_bsz_T_d, dim=0)  # (T_1+T_2+...+T_n, d)
        else:
            embeddings_a_bsz_d = embeddings_a_bsz_T_d.reshape(-1, embeddings_a_bsz_T_d.size(-1))  # (bsz * T, d)
        if isinstance(embeddings_b_bsz_T_d, list):
            embeddings_b_bsz_d = torch.cat(embeddings_b_bsz_T_d, dim=0)  # (T_1+T_2+...+T_n, d)
        else:
            embeddings_b_bsz_d = embeddings_b_bsz_T_d.reshape(-1, embeddings_b_bsz_T_d.size(-1))  # (bsz * T, d)

        if self.pick_k:
            embeddings_a_bsz_d, embeddings_b_bsz_d = self.pick_k_pairs(embeddings_a_bsz_d, embeddings_b_bsz_d, seq_lens)

        scores = self.similarity_fct(embeddings_a_bsz_d, embeddings_b_bsz_d) * self.scale  # [b, b]

        assert embeddings_a_bsz_d.size(0) * self.target_per_qry == embeddings_b_bsz_d.size(0), "Number of negatives must be equal to number of positives times target_per_query (something is off with hard negatives)"
        labels = torch.tensor(
            range(0, embeddings_a_bsz_d.size(0) * self.target_per_query), self.target_per_query, dtype=torch.long, device=scores.device
        )  # Example a[i] should match with b[i]    [0, 1, 2, 3, ...] or [0, 2, 4, 6, ...] if target_per_query = 2 when have 1 hard negative for each pair

        if self.mask_k_ldiags:
            scores = self.mask_k_lower_diags_fast(scores, seq_lens)
        if self.mask_k_udiags:
            scores = self.mask_k_upper_diags(scores, seq_lens)
        if self.k_pos_labels != None:
            labels_w_k_positives = self.create_k_positives(scores, seq_lens, labels)  # [b, b]
            labels = (
                self.decay_label_probs(labels_w_k_positives) if self.k_pos_labels > 0 else labels_w_k_positives
            )  # cos we don't want to do decay when it's only 1 label per sample
        if self.n_gram:
            scores, labels, keep_indices = self.mask_n_gram_rows(scores, seq_lens, labels)
            if self.k_pos_labels != None:
                labels_w_k_positives = labels_w_k_positives[keep_indices]
        outputs = {
            "loss": 0.0,
            "accuracy": 0.0,
        }
        # softmaxing scores over suffix dimension
        outputs["loss"] += self.cross_entropy_loss(scores, labels)
        if self.k_pos_labels != None:
            outputs["accuracy"] = self.top_k_recall(scores, labels_w_k_positives)
        else:
            outputs["accuracy"] += (
                (torch.argmax(scores, dim=1) == labels).float().mean()
            )  # we want to check the retrieval accuracy

        return outputs

    def create_k_positives(self, scores, seq_lens, diag_positions):
        diag_positions = (
            diag_positions if diag_positions is not None else torch.arange(scores.size(0), device=scores.device)
        )
        n, m = scores.shape
        labels = torch.zeros(n, m, device=scores.device)

        # Set the initial diagonal positions
        labels[torch.arange(n, device=scores.device), diag_positions] = 1.0

        # Create a mask for the upper diagonal elements we want to set to 1
        mask = torch.zeros_like(scores, dtype=torch.bool)

        # Calculate the indices where we want to place 1s
        start = 0
        for i, size in enumerate(seq_lens):
            end = start + size
            for diag in range(1, min(self.k_pos_labels + 1, size)):  # Limit diagonal to size
                col_start = diag_positions[start] + diag
                upper_bound = min(end - diag, scores.size(0) - diag)
                if start < upper_bound:
                    indices = torch.arange(start, upper_bound)
                    col_indices = torch.arange(col_start, col_start + indices.size(0))
                    mask[indices, col_indices] = True
            start = end

        # Use the mask to set the desired elements to 1
        labels = labels.masked_fill(mask, 1.0)
        return labels

    def mask_k_upper_diags(self, scores, seq_lens, diag_positions=None):
        diag_positions = (
            diag_positions if diag_positions is not None else torch.arange(scores.size(0), device=scores.device)
        )

        # Create a mask for the upper diagonal elements we want to set to -inf
        mask = torch.zeros_like(scores, dtype=torch.bool)

        # Calculate the indices where we want to place -inf
        start = 0
        for i, size in enumerate(seq_lens):
            end = start + size
            for diag in range(1, min(self.mask_k_udiags + 1, size)):
                col_start = diag_positions[start] + diag
                upper_bound = min(end - diag, scores.size(0) - diag)
                if start < upper_bound:
                    indices = torch.arange(start, upper_bound)
                    col_indices = torch.arange(col_start, col_start + indices.size(0))
                    mask[indices, col_indices] = True
            start = end

        if self.k_pos_labels is None:
            scores = scores.masked_fill(
                mask, float("-inf")
            )  # we can't use float('-inf') as the mask value because it causes NaNs for CE kernel for 2D labels
        else:
            scores = scores.masked_fill(mask, -1e20)
        return scores

    def mask_k_lower_diags_fast(self, scores, seq_lens, diag_positions=None):
        """
        Efficiently mask lower k diagonals while respecting sequence boundaries and diagonal positions.
        Args:
            scores: input tensor of shape (N, M) where N = sum(seq_lens) and M >= N
            seq_lens: list of sequence lengths that sum to N
            diag_positions: starting column positions for each row's diagonal pattern (consecutive integers)
        """
        device = scores.device
        n_rows = scores.size(0)
        n_cols = scores.size(1)

        # If diag_positions not provided, use default consecutive positions
        diag_positions = (
            diag_positions if diag_positions is not None else torch.arange(scores.size(0), device=scores.device)
        )

        # Create position tensor indicating sequence start positions
        seq_starts = torch.zeros(n_rows, device=device)
        diag_starts = torch.zeros(n_rows, device=device)
        pos = 0
        for i, length in enumerate(seq_lens):
            seq_starts[pos : pos + length] = pos
            diag_starts[pos : pos + length] = diag_positions[pos]
            pos += length

        # Create row and column indices tensors
        row_idx = torch.arange(n_rows, device=device).unsqueeze(1).expand(-1, n_cols)
        rel_cols = torch.arange(n_cols, device=device).unsqueeze(0).expand(n_rows, -1) - diag_starts.unsqueeze(1)

        # Calculate relative positions considering diagonal offsets
        # rel_cols = col_idx - diag_starts.unsqueeze(1)

        # Create sequence boundary mask
        seq_mask = torch.zeros_like(scores, dtype=torch.bool)
        pos = 0
        for i, length in enumerate(seq_lens):
            # Calculate valid column range for this sequence
            col_start = diag_positions[pos]
            col_end = col_start + length
            seq_mask[pos : pos + length, col_start:col_end] = True
            pos += length

        # Create diagonal mask considering the diagonal positions
        diag_mask = (rel_cols >= 0) & (rel_cols < self.mask_k_ldiags) & (rel_cols < (row_idx - seq_starts.unsqueeze(1)))

        # Combine masks
        diag_mask &= seq_mask

        # Apply mask
        if self.k_pos_labels is None:
            scores = scores.masked_fill_(diag_mask, float("-inf"))
        else:
            scores = scores.masked_fill_(diag_mask, -1e20)
        return scores

    def mask_full_lower_diag(self, scores, seq_lens, diag_positions=None):
        """
        Masks out the whole lower diagonal negatives for each sequence in scores matrix (those are considered bad negatives).
        """
        device = scores.device
        mask_value = float("-inf") if self.k_pos_labels is None else -1e20

        # If diag_positions not provided, use default consecutive positions
        diag_positions = (
            diag_positions if diag_positions is not None else torch.arange(scores.size(0), device=scores.device)
        )

        pos = 0
        for i, seq_len in enumerate(seq_lens):
            seq_mask = torch.tril(torch.ones(seq_len, seq_len, device=device), diagonal=-1).bool()
            start_idx = diag_positions[pos]
            end_idx = start_idx + seq_len
            scores[pos : pos + seq_len, start_idx:end_idx] = scores[pos : pos + seq_len, start_idx:end_idx].masked_fill(
                seq_mask, mask_value
            )
            pos += seq_len

        return scores

    def sample_cross_device_negatives(self, embeddings_b, labels):
        """
        Sample k indices from embeddings_b that don't appear in labels while maintaining
        the correct mapping between labels and embeddings.
        """

        bsz = embeddings_b.shape[0]
        device = embeddings_b.device

        all_indices = torch.arange(bsz, device=device)

        # Find indices that are not in labels
        mask = torch.ones(bsz, dtype=torch.bool, device=device)
        mask[labels] = False
        negative_indices = all_indices[mask]

        if self.keep_k_cross_device_negatives > len(negative_indices):
            # if more negatives are requested than available, then we just return all negatives
            return embeddings_b, labels

        # Sample k indices from the negative set
        perm = torch.randperm(len(negative_indices), device=device, generator=self.generator)[: self.keep_k_cross_device_negatives]
        sampled_negative_indices = negative_indices[perm]

        # Combine sampled negative indices with label indices
        final_indices = torch.cat([labels, sampled_negative_indices])

        # Sample the embeddings using the final indices
        sampled_embeddings = embeddings_b[final_indices]

        # New labels are just range(len(labels)) since local suffix embeddings are at the start
        new_labels = torch.arange(len(labels), device=device)

        return sampled_embeddings, new_labels

    def create_new_diag_labels(self, labels, seq_lens, k):
        """
        adds an offset k to elements in sequences where k + label < seq_len
        """
        if isinstance(seq_lens, list):
            seq_lens = torch.tensor(seq_lens, device=labels.device)
        # Get cumulative sequence lengths for determining sequence boundaries
        cum_lens = torch.cat([torch.tensor([0], device=labels.device), torch.cumsum(seq_lens, dim=0)])

        # Create output tensor
        new_labels = labels.clone()

        # For each sequence, add k to valid positions
        for i in range(len(seq_lens)):
            start_idx = cum_lens[i]
            valid_len = XXXX-13(0, seq_lens[i] - k)  # Number of positions where we can add k
            if valid_len > 0:
                valid_positions = start_idx + torch.arange(valid_len, device=labels.device)
                new_labels[valid_positions] += k

        return new_labels

    def create_new_diag_labels_randomly(self, scores, labels, seq_lens, k):
        """
        adds an offset k to elements in sequences where k + label < seq_len
        """
        def mask_k_upper_diags(scores, seq_lens, k_udiags=None, diag_positions=None, new_labels=None):
            diag_positions = (
                diag_positions if diag_positions is not None else torch.arange(scores.size(0), device=scores.device)
            )

            # Create a mask for the upper diagonal elements we want to set to -inf
            mask = torch.zeros_like(scores, dtype=torch.bool)

            # Calculate the indices where we want to place -inf
            start = 0
            for i, size in enumerate(seq_lens):
                end = start + size
                for diag in range(1, min(k_udiags + 1, size)):
                    col_start = diag_positions[start] + diag
                    upper_bound = min(end - diag, scores.size(0) - diag)
                    if start < upper_bound:
                        indices = torch.arange(start, upper_bound)
                        col_indices = torch.arange(col_start, col_start + indices.size(0))
                        mask[indices, col_indices] = True
                start = end

            if new_labels is not None:
                # we'll put false in new_labels positions (which are the column indices)
                row_indices = torch.arange(scores.size(0), device=scores.device)
                mask[row_indices, diag_positions] = True
                mask[row_indices, new_labels] = False

            scores = scores.masked_fill(mask, float("-inf"))
            return scores

        if isinstance(seq_lens, list):
            seq_lens = torch.tensor(seq_lens, device=labels.device)
        # Get cumulative sequence lengths for determining sequence boundaries
        cum_lens = torch.cat([torch.tensor([0], device=labels.device), torch.cumsum(seq_lens, dim=0)])

        # Create output tensor
        new_labels = labels.clone()

        # For each sequence, add k to valid positions
        for i in range(len(seq_lens)):
            start_idx = cum_lens[i]
            valid_len = XXXX-13(0, seq_lens[i] - k)  # Number of positions where we can add k
            if valid_len > 0:
                valid_positions = start_idx + torch.arange(valid_len, device=labels.device)
                new_labels[valid_positions] += k

        scores = mask_k_upper_diags(
            scores,
            seq_lens,
            k_udiags=self.k_random_pos_labels,
            diag_positions=labels,
            new_labels=new_labels if k > 0 else None,
        )

        return scores, new_labels

    def mask_n_gram_rows(self, scores, seq_lens, labels):
        """
        Mask out first and last n rows of each sequence in scores matrix and remove corresponding labels.
        """
        assert scores.size(0) == sum(seq_lens), "Scores rows must match sum of sequence lengths"
        assert len(labels) == scores.size(0), "Labels length must match scores rows"

        # Convert labels to tensor if not already
        labels = torch.tensor(labels) if not torch.is_tensor(labels) else labels

        # Calculate which indices to keep (exclude first and last n from each sequence)
        keep_indices = []
        start_idx = 0

        for seq_len in seq_lens:
            # For each sequence, keep all indices except first and last n
            # Note: Need to ensure we don't overlap when sequence is too short
            n_gram = min(self.n_gram, seq_len // 2)  # Prevent overlapping for short sequences
            seq_indices = torch.arange(start_idx + n_gram, start_idx + seq_len - n_gram)
            keep_indices.append(seq_indices)
            start_idx += seq_len

        # Concatenate all indices to keep
        keep_indices = torch.cat(keep_indices)

        # Apply masking
        masked_scores = scores[keep_indices]
        masked_labels = labels[keep_indices]

        return masked_scores, masked_labels, keep_indices

    def pick_k_pairs(self, embs_a, embs_b, seq_lens, labels=None):
        """
        Randomly pick k pairs from every sequence.
        """
        assert embs_a.size(0) == embs_b.size(0), "Tensors must have the same batch size"
        assert self.pick_k > 0, "We can't pick 0 pairs, make pick_k higher"

        # embs_a has len(seq_lens) sequences, so we need to pick k pairs from each sequence
        # e.g. if seq_lens = [4, 5, 6], and self.pick_k=1, we need to randomly pick 1 pair from the first sequence, 1 from the second, and 1 from the third

        # Generate random unique indices to pick from each sequence
        selected_indices = []
        start_idx = 0

        for seq_len in seq_lens:
            # For each sequence, randomly pick k indices
            if seq_len < self.pick_k:
                # If sequence is too short, pick all indices
                seq_indices = torch.arange(start_idx, start_idx + seq_len)
            else:
                # Otherwise, randomly pick k indices
                seq_indices = torch.randperm(seq_len)[: self.pick_k] + start_idx
            selected_indices.append(seq_indices)
            start_idx += seq_len

        # Concatenate all selected indices
        selected_indices = torch.cat(selected_indices)

        # Pick the embeddings using the selected indices
        picked_embs_a = embs_a[selected_indices]
        picked_embs_b = embs_b[selected_indices]

        if labels is not None:
            # Pick the labels using the selected indices
            picked_labels = labels[selected_indices]
            return picked_embs_a, picked_embs_b, picked_labels

        return picked_embs_a, picked_embs_b

    def decay_label_probs(self, labels):
        # Compute log decay factors only for non-zero elements
        nonzero = labels.nonzero(as_tuple=True)
        log_decay = nonzero[1].to(torch.float32) * self.log_decay_factor.to(labels.device)

        # Compute log probabilities
        log_probs = torch.full_like(labels, float("-inf"), dtype=torch.float32)
        log_probs[nonzero] = log_decay

        # Normalize in log space
        return (log_probs - torch.logsumexp(log_probs, dim=1, keepdim=True)).exp()

    def top_k_recall(self, predictions, labels):
        # Get the indices of the top k predictions
        _, top_k_indices = predictions.topk(
            self.k_pos_labels + 1, dim=1
        )  # we do k+1 since in k=1 setting we've two labels per sample and k=0 means the true labels are in diagonal

        total_true = labels.sum()
        correct_count = 0

        for i in range(labels.shape[0]):  # Iterate over each sample
            sample_labels = labels[i].nonzero().squeeze(1)
            sample_predictions = top_k_indices[i]
            correct_count += torch.sum(torch.isin(sample_predictions, sample_labels))

        # Calculate recall
        recall = correct_count / total_true if total_true > 0 else 0.0
        return recall

    def get_config_dict(self):
        return {"scale": self.scale, "similarity_fct": self.similarity_fct.__name__}
