import warnings    
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np

def clip_loss(text_embedding, crystal_embedding, params):
    """
    Calculate the symmetric CLIP-style cross entropy loss for text and crystal embeddings.
    This implementation is based on `https://github.com/moein-shariatnia/OpenAI-CLIP/blob/master/CLIP.py` .

    The i-th text and the i-th crystal of the batch are treated as the matching
    pair; every other combination in the batch acts as a negative.

    Args:
      text_embedding: Tensor, shape [batch_size, embedding_size], feature embeddings of texts.
      crystal_embedding: Tensor, shape [batch_size, embedding_size], feature embeddings of crystal structures.
      params: Namespace-like object or dict. Only ``loss_scale`` is read (as an
        attribute, or as a key when ``params`` is a dict); it multiplies the
        similarity logits and defaults to 1.0 when absent.

    Returns:
      loss: scalar float tensor, the mean of the text-to-crystal and
        crystal-to-text cross entropy losses.

    """
    # `hasattr` never sees dict keys, so a dict needs an explicit key lookup;
    # otherwise its loss_scale would be silently replaced by the default.
    if hasattr(params, "loss_scale"):
        loss_scale = params.loss_scale
    elif isinstance(params, dict):
        loss_scale = params.get("loss_scale", 1.0)
    else:
        loss_scale = 1.0

    # Row i holds the similarities of text i against every crystal; the
    # crystal-to-text logits are simply the transpose, so compute them once.
    logits_per_text = loss_scale * (text_embedding @ crystal_embedding.T)
    logits_per_crystal = logits_per_text.T

    # The matching pair for row i sits in column i.
    labels = torch.arange(logits_per_text.shape[0], dtype=torch.long, device=logits_per_text.device)

    total_loss = (
        F.cross_entropy(logits_per_crystal, labels) +
        F.cross_entropy(logits_per_text, labels)
    ) / 2.

    return total_loss

# CosFace
def cosface_loss(text_embedding, crystal_embedding, params):
    """
    Symmetric CosFace-style loss over a batch of paired text/crystal embeddings.

    A fixed margin is subtracted from the similarity of each matching pair
    (the diagonal of the similarity matrix) before the scaled similarities are
    fed to cross entropy in both directions.

    Args:
      text_embedding: Tensor, shape [batch_size, embedding_size], feature embeddings of texts.
      crystal_embedding: Tensor, shape [batch_size, embedding_size], feature embeddings of crystals.
      params: Object with attribute ``margin`` (required) and optional
        attribute ``loss_scale`` (defaults to 10.0 when absent).

    Returns:
      loss: scalar float tensor, the mean of the text-to-crystal and
        crystal-to-text cross entropy losses.
    """
    margin = params.margin
    loss_scale = getattr(params, "loss_scale", 10.0)

    device = text_embedding.device
    n_pairs = text_embedding.size(0)

    # Pairwise dot products; these are cosine similarities only if the
    # embeddings arrive L2-normalized (assumption -- not enforced here).
    similarity = text_embedding @ crystal_embedding.T

    # Pair i of the batch is the positive for row i / column i.
    targets = torch.arange(n_pairs, dtype=torch.long, device=device)

    # Penalize only the positive pairs, then scale all logits.
    diagonal_penalty = torch.eye(n_pairs, device=device) * margin
    logits = (similarity - diagonal_penalty) * loss_scale

    text_to_crystal = F.cross_entropy(logits, targets)
    crystal_to_text = F.cross_entropy(logits.T, targets)

    return (text_to_crystal + crystal_to_text) / 2


# ArcFace
def arcface_loss(text_embedding, crystal_embedding, params):
    """
    Calculate ArcFace loss for text and crystal embeddings.

    For each matching pair (the diagonal of the similarity matrix) the angle
    between the two embeddings is increased by ``margin`` before converting
    back to a cosine; the resulting scaled logits go through cross entropy in
    both the text-to-crystal and crystal-to-text directions.

    Args:
        text_embedding: Tensor, shape [batch_size, embedding_size], feature embeddings of texts.
        crystal_embedding: Tensor, shape [batch_size, embedding_size], feature embeddings of crystals.
        params: Object containing the following attributes:
            margin: float, angular margin (radians) added to the positive-pair angle.
            loss_scale: float, optional scaling factor applied to the logits
                (defaults to 10.0 when the attribute is absent).

    Returns:
        loss: scalar float tensor, the calculated ArcFace loss.
    """
    margin = params.margin
    loss_scale = getattr(params, "loss_scale", 10.0)

    # Pairwise dot products; these are cosine similarities only if the
    # embeddings arrive L2-normalized (assumption -- not enforced here).
    cosine_matrix = text_embedding @ crystal_embedding.T

    # Pair i of the batch is the positive for row i / column i.
    batch_size = text_embedding.size(0)
    labels = torch.arange(batch_size, dtype=torch.long, device=text_embedding.device)

    # Clamp away from +/-1 so arccos stays finite and has a finite gradient.
    eps = 1e-5
    diag_cos = torch.clamp(cosine_matrix.diagonal(), -1 + eps, 1 - eps)
    # Add the margin in angle space, then map back to a cosine.
    diag_cos_margin = torch.cos(torch.arccos(diag_cos) + margin)

    # ArcFace logits: cos(theta + m) on the positives, plain cos(theta)
    # elsewhere. Only the margin-adjusted matrix is used; adding the raw
    # cosine matrix on top would double the scale and dilute the margin.
    logits = loss_scale * torch.diagonal_scatter(cosine_matrix, diag_cos_margin, 0)

    loss_text = F.cross_entropy(logits, labels)
    loss_crystal = F.cross_entropy(logits.T, labels)

    # Combine both retrieval directions
    total_loss = (loss_text + loss_crystal) / 2

    return total_loss