import torch

import torch.nn as nn

def contrastive_loss(temp=0.07, epsilon=0.1):
    """
    Create a callable supervised contrastive loss with preset temp and epsilon.

    Args:
        temp (float): Temperature that divides the cosine similarities.
        epsilon (float): Value returned unchanged next to the loss. The loss
            computation itself never uses it; presumably the caller uses it
            as the weight when mixing this loss with an NLL term (confirm
            against the caller).

    Returns:
        Callable: ``loss_fn(model_output, pos_neg_labels, num_items_in_batch=None)``
        which returns the tuple ``(cl_loss, epsilon)``.
    """
    def loss_fn(model_output, pos_neg_labels, num_items_in_batch=None):
        """
        Compute the contrastive loss over the last-token embeddings of a batch.

        Samples sharing a label are treated as positives of each other; every
        other sample in the batch (excluding the sample itself) contributes to
        the softmax denominator.

        Args:
            model_output: Mapping whose ``"hidden_states"`` entry is a sequence
                of tensors; the last one is indexed as ``[:, -1, :]``, so it
                must be 3-D with the batch on the first axis.
            pos_neg_labels: One label per sample, given either as a flat
                tensor or as a column tensor.
            num_items_in_batch: Accepted for signature compatibility and
                ignored. The pairwise masks must match the embedding batch
                dimension, so an externally supplied count cannot be used to
                size them.

        Returns:
            tuple: ``(cl_loss, epsilon)`` where ``cl_loss`` is a scalar tensor
            averaged over the samples that have at least one positive, and
            ``epsilon`` is the value given to ``contrastive_loss``. When no
            sample has a positive, ``cl_loss`` is a zero that is still
            attached to the autograd graph.

        Raises:
            ValueError: If the number of labels differs from the batch size.
        """
        # Only the last token is used, so normalize just that slice; this is
        # equivalent to normalizing every token and slicing afterwards.
        final_hidden = model_output["hidden_states"][-1]
        last_token_embeddings = torch.nn.functional.normalize(
            final_hidden[:, -1, :], p=2, dim=-1
        )
        batch_size = last_token_embeddings.shape[0]
        device = last_token_embeddings.device

        # Force a column layout so the transpose comparison below is a true
        # pairwise comparison even when the labels arrive as a flat tensor.
        labels = torch.as_tensor(pos_neg_labels, device=device).reshape(-1, 1)
        if labels.shape[0] != batch_size:
            raise ValueError(
                f"Expected {batch_size} labels (one per sample), got {labels.shape[0]}"
            )

        # Positives: same label, excluding each sample paired with itself.
        off_diagonal = ~torch.eye(batch_size, dtype=torch.bool, device=device)
        label_mask = (labels == labels.T) & off_diagonal
        n_i = label_mask.sum(dim=1)
        valid_mask = n_i > 0

        if not bool(valid_mask.any()):
            # No positive pairs (e.g. batch of one or all-distinct labels):
            # averaging an empty selection would give NaN, so return a zero
            # that keeps the graph connected instead.
            return last_token_embeddings.sum() * 0.0, epsilon

        # Temperature-scaled cosine similarities.
        logits = torch.matmul(last_token_embeddings, last_token_embeddings.T) / temp

        # log(exp(s_ij) / sum_{k != i} exp(s_ik)) evaluated with logsumexp so
        # that small temperatures do not overflow exp().
        log_row_sum = torch.logsumexp(
            logits.masked_fill(~off_diagonal, float("-inf")), dim=1, keepdim=True
        )
        cl_log_probs = logits - log_row_sum

        # Keep only the positive pairs; where() avoids multiplying by zero.
        masked_log_probs = torch.where(
            label_mask, cl_log_probs, torch.zeros_like(cl_log_probs)
        )

        # Average over each sample's positives, then over the valid samples.
        per_sample_loss = -masked_log_probs.sum(dim=1)[valid_mask] / n_i[valid_mask]
        cl_loss = per_sample_loss.mean()

        return cl_loss, epsilon

    return loss_fn