import torch  

def diversity_loss(P, K):  
    """  
    计算多样性损失 Ldiv。支持二维 (K, Np) 或三维 (B, K, Np) 输入  
    
    参数:  
    - P: 形状为 (K, Np) 或 (B, K, Np) 的张量，表示 K 个关键点的属性嵌入  
    - K: int，关键点的数量  

    返回:  
    - loss: 计算出的多样性损失（标量）  
    """  

    P_transposed = P.transpose(-1, -2)  

    P_outer = torch.matmul(P_transposed, P)  


    I = torch.eye(P.size(-1), device=P.device) 


    diff = P_outer - I  


    if P.dim() == 3:
       
        loss_per_sample = torch.linalg.norm(diff, ord='fro', dim=(1,2)) ** 2  # 形状 (B,)
        loss = (loss_per_sample / K).mean() 
    else:

        loss = torch.norm(diff, p='fro') ** 2 / K

    return loss


B = 2  
K = 5  
Np = 3 

P_3d = torch.randn(B, K, Np)  
loss_3d = diversity_loss(P_3d, K)
print("3D Input Loss:", loss_3d.item())


P_2d = torch.randn(K, Np)     
loss_2d = diversity_loss(P_2d, K)
print("2D Input Loss:", loss_2d.item())