import torch
import torch.nn.functional as F

def clip_loss(text_embedding, crystal_embedding, params):
    """
    Calculate the symmetric CLIP-style contrastive (cross-entropy) loss for text and crystal embeddings.
    This implementation is based on `https://github.com/moein-shariatnia/OpenAI-CLIP/blob/master/CLIP.py` .

    Row i of ``text_embedding`` and row i of ``crystal_embedding`` form a positive pair;
    every other row of the batch acts as a negative.

    Args:
      text_embedding: Tensor, shape [batch_size, embedding_size], feature embeddings of texts.
      crystal_embedding: Tensor, shape [batch_size, embedding_size], feature embeddings of crystal structures.
      params: attribute-style config object (e.g. Namespace) or dict. The optional
        ``loss_scale`` entry (attribute or key, default 1.0) multiplies the similarity
        logits. No other entries (e.g. ``reduction_mode``) are consulted here.

    Returns:
      loss: scalar float tensor, the mean of the crystal->text and text->crystal
        cross-entropy losses.
    """
    # hasattr()/getattr() do not see dict keys, so dicts are read explicitly;
    # otherwise a dict's "loss_scale" would be silently ignored.
    if isinstance(params, dict):
        loss_scale = params.get("loss_scale", 1.0)
    else:
        loss_scale = getattr(params, "loss_scale", 1.0)

    logits_per_crystal = loss_scale * crystal_embedding @ text_embedding.T
    # The text->crystal logits are exactly the transpose; reuse them instead of a second matmul.
    logits_per_text = logits_per_crystal.T

    # Positive pairs sit on the diagonal, so the target class of row i is i.
    labels = torch.arange(logits_per_text.shape[0], dtype=torch.long, device=logits_per_text.device)

    total_loss = (
        F.cross_entropy(logits_per_crystal, labels) +
        F.cross_entropy(logits_per_text, labels)
    ) / 2.

    return total_loss

def cosface_loss(text_embedding, crystal_embedding, params):
    """
    Calculate a symmetric CosFace (large-margin cosine) loss for text and crystal embeddings.

    The margin is subtracted from the positive-pair (diagonal) similarities before
    scaling, and cross-entropy is averaged over the text->crystal and
    crystal->text directions.

    Args:
      text_embedding: Tensor, shape [batch_size, embedding_size], feature embeddings of texts.
      crystal_embedding: Tensor, shape [batch_size, embedding_size], feature embeddings of crystals.
      params: attribute-style config object (e.g. Namespace) or dict providing
        ``margin`` (required; float subtracted from the diagonal similarities) and
        optionally ``loss_scale`` (default 10.0; multiplies the margin-adjusted similarities).

    Returns:
      loss: scalar float tensor, the calculated CosFace loss.

    Raises:
      AttributeError / KeyError: if ``params`` does not provide ``margin``.
    """
    # hasattr()/getattr() do not see dict keys, so dicts are read explicitly.
    if isinstance(params, dict):
        margin = params["margin"]
        loss_scale = params.get("loss_scale", 10.0)
    else:
        margin = params.margin
        loss_scale = getattr(params, "loss_scale", 10.0)

    # NOTE(review): this is only a cosine similarity if both inputs are already
    # L2-normalized; this function does not normalize them — confirm callers do.
    cosine_matrix = text_embedding @ crystal_embedding.T

    batch_size = text_embedding.size(0)

    # Positive pairs sit on the diagonal, so the target class of row i is i.
    labels = torch.arange(batch_size, dtype=torch.long, device=text_embedding.device)

    # Apply the margin only to the diagonal. Build the identity in the logits'
    # dtype/device so inputs are not mixed with a default-float32 tensor.
    margin_matrix = torch.eye(
        batch_size, dtype=cosine_matrix.dtype, device=cosine_matrix.device
    ) * margin
    cosine_with_margin = (cosine_matrix - margin_matrix) * loss_scale

    # Rows: text -> crystal; transposed rows: crystal -> text.
    loss_text = F.cross_entropy(cosine_with_margin, labels)
    loss_crystal = F.cross_entropy(cosine_with_margin.T, labels)

    total_loss = (loss_text + loss_crystal) / 2

    return total_loss