import torch
from torch.nn.modules.loss import _Loss
import torch.nn as nn
import torch.nn.functional as F
from pointcept.models.losses.builder import LOSSES

def lj_loss(features, sigma=1.0, n=6, clamp_max=5.0, min_dist=1e-3):
    """
    Lennard-Jones-style loss over pairwise cosine distances of a batch of feature vectors.

    For every ordered pair (i, j) of rows in ``features``, the cosine distance
    ``d_ij = 1 - cos(f_i, f_j)`` (range [0, 2]) is mapped to the potential

        V(d) = (sigma / d) ** (2 * n) - (sigma / d) ** n

    V is zero at ``d == sigma`` and negative (bounded below by -1/4) for ``d > sigma``.
    It grows very quickly for ``d < sigma``, so pairs that are too close are pushed apart.
    The returned loss is the mean of the clamped potential over all N * N pairs.

    Args:
        features: Tensor of shape (batch_size, num_features), the feature vectors for points.
        sigma: The distance at which the potential is zero (ideal distance in cosine space).
        n: Exponent of the attractive term; the repulsive term uses ``2 * n``.
        clamp_max: Upper bound applied to each pairwise potential to avoid extreme loss values.
        min_dist: Lower bound applied to the cosine distance. It prevents division by zero
            for (near-)duplicate feature vectors.

    Returns:
        Scalar tensor holding the mean pairwise potential. Self-pairs (the diagonal) are
        pinned to ``d == sigma``, so they contribute 0 when ``sigma >= min_dist``, but they
        still count in the mean. An empty batch returns 0 instead of NaN.
    """
    num_points = features.shape[0]
    if num_points == 0:
        # torch.mean over an empty tensor is NaN. Return a zero that stays attached
        # to the autograd graph (and keeps dtype/device) instead.
        return features.sum() * 0.0

    features_normalized = F.normalize(features, p=2, dim=1)
    cosine_sim = torch.matmul(features_normalized, features_normalized.T)  # Cosine similarity matrix

    # Convert cosine similarity to cosine distance (in range [0, 2]).
    cosine_dist = 1 - cosine_sim

    # A self-pair has distance 0, which would dominate the loss. Pin it to sigma so its
    # potential is (1 - 1) = 0. Build the index on the tensor's own device so there is
    # no host->device copy on every call.
    diag_indices = torch.arange(num_points, device=cosine_dist.device)
    cosine_dist[diag_indices, diag_indices] = sigma

    # Clamp distances to avoid division by zero. This also absorbs tiny negative values
    # caused by floating-point error when cos_sim slightly exceeds 1.
    cosine_dist = torch.clamp(cosine_dist, min=min_dist)

    ratio = sigma / cosine_dist
    lj_potential = ratio ** (2 * n) - ratio ** n

    # Clamp the potential to avoid extreme loss values from near-duplicate points.
    lj_potential = torch.clamp(lj_potential, max=clamp_max)

    return torch.mean(lj_potential)


@LOSSES.register_module()
class LJLoss(nn.Module):
    """Module wrapper that scales ``lj_loss`` by a constant weight ``epsilon``.

    Suggested settings:
        classification: epsilon=0.1, sigma=0.5 (these are the constructor defaults)
        segmentation:   epsilon=4e-3, sigma=1.0

    Args:
        epsilon: Multiplicative weight applied to the loss value.
        sigma: Cosine distance at which the pairwise potential is zero; forwarded to ``lj_loss``.
        n: Exponent forwarded to ``lj_loss`` (the repulsive term uses ``2 * n``).
        clamp_max: Upper clamp on each pairwise potential; forwarded to ``lj_loss``.
    """

    def __init__(self, epsilon=0.1, sigma=0.5, n=6, clamp_max=5.):
        super().__init__()
        self.epsilon = epsilon
        self.sigma = sigma
        self.n = n
        self.clamp_max = clamp_max

    def forward(self, feat):
        """Return ``epsilon * lj_loss(feat, ...)`` using the stored hyper-parameters."""
        raw_loss = lj_loss(feat, sigma=self.sigma, n=self.n, clamp_max=self.clamp_max)
        return self.epsilon * raw_loss
