import torch
import torch.nn.functional as F

## TCL loss: CLIP contrastive loss with additional negative text embeddings
def clip_contrastive_loss_negative(image_embeddings, text_embeddings, negative_text_embeddings, temperature=0.07):
    """
    Calculate a CLIP-style contrastive loss with extra negative text embeddings.

    The image-to-text direction scores every image against all in-batch text
    embeddings plus all negative text embeddings. The text-to-image direction
    uses only the in-batch image/text pairs; the negatives do not take part in it.

    Args:
    - image_embeddings (torch.Tensor): Embeddings from the vision encoder of shape (batch_size, embedding_dim).
    - text_embeddings (torch.Tensor): Embeddings from the text encoder of shape (batch_size, embedding_dim).
      Row i is the positive match for image i.
    - negative_text_embeddings (torch.Tensor): Additional text embeddings of shape
      (num_negatives, embedding_dim) that only ever act as wrong answers for the
      image-to-text direction. num_negatives does not have to equal batch_size.
    - temperature (float): Temperature parameter for scaling the logits.

    Returns:
    - Tuple[torch.Tensor, float]: The loss (mean of the image-to-text and text-to-image
      cross-entropies) and the accuracy (mean of the two top-1 accuracies) as a Python float.
    """
    # Normalize the embeddings so the dot products below are cosine similarities
    image_embeddings = F.normalize(image_embeddings, p=2, dim=1)
    text_embeddings = F.normalize(text_embeddings, p=2, dim=1)
    negative_text_embeddings = F.normalize(negative_text_embeddings, p=2, dim=1)

    # Temperature-scaled similarity matrices: (batch, batch) and (batch, num_negatives)
    similarity = torch.matmul(image_embeddings, text_embeddings.t()) / temperature
    similarity_negative = torch.matmul(image_embeddings, negative_text_embeddings.t()) / temperature

    # Image-to-text logits, built once and shared by the loss and the accuracy.
    # The in-batch columns come first, so the diagonal labels below still point
    # at the positive text; the appended columns are extra negatives only.
    logits_i2t = torch.cat((similarity, similarity_negative), dim=-1)
    logits_t2i = similarity.t()

    # Labels for matching pairs: image i <-> text i
    labels = torch.arange(similarity.size(0), device=similarity.device)

    # Calculate the loss for both image-to-text and text-to-image
    loss_i2t = F.cross_entropy(logits_i2t, labels)
    loss_t2i = F.cross_entropy(logits_t2i, labels)

    # Calculate top-1 accuracy in both directions
    acc_i2t = (logits_i2t.argmax(dim=1) == labels).float().mean().item()
    acc_t2i = (logits_t2i.argmax(dim=1) == labels).float().mean().item()
    accuracy = (acc_i2t + acc_t2i) / 2

    # Final loss is the average of both losses
    loss = (loss_i2t + loss_t2i) / 2

    return loss, accuracy

## Standard CLIP loss
def clip_contrastive_loss(image_embeddings, text_embeddings, temperature=0.07):
    """
    Symmetric CLIP contrastive loss over a batch of paired image/text embeddings.

    Row i of the image embeddings and row i of the text embeddings are treated as the
    matching pair; every other row in the batch serves as a negative.

    Args:
    - image_embeddings (torch.Tensor): Vision-encoder output of shape (batch_size, embedding_dim).
    - text_embeddings (torch.Tensor): Text-encoder output of shape (batch_size, embedding_dim).
    - temperature (float): Divisor applied to the similarities before the softmax.

    Returns:
    - Tuple[torch.Tensor, float]: The loss averaged over the image-to-text and text-to-image
      directions, and the top-1 accuracy averaged over the same two directions.
    """
    # Unit-length rows make the dot products below cosine similarities
    img = F.normalize(image_embeddings, p=2, dim=1)
    txt = F.normalize(text_embeddings, p=2, dim=1)

    # logits[i, j] is the temperature-scaled similarity of image i and text j
    logits = (img @ txt.t()) / temperature

    # The correct "class" for row i is column i
    targets = torch.arange(logits.size(0), device=logits.device)

    def _score_direction(scores):
        # Cross-entropy and top-1 hit rate for one retrieval direction
        ce = F.cross_entropy(scores, targets)
        hit_rate = (scores.argmax(dim=1) == targets).float().mean().item()
        return ce, hit_rate

    loss_i2t, acc_i2t = _score_direction(logits)
    loss_t2i, acc_t2i = _score_direction(logits.t())

    # Both outputs are plain averages of the two directions
    return (loss_i2t + loss_t2i) / 2, (acc_i2t + acc_t2i) / 2
