#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Loss functions for point cloud classification with contrastive learning
- NT-Xent contrastive loss (for inter-object contrast)
- CrossEntropyLoss (for classification)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


def contrastive_nt_xent(feats: torch.Tensor,
                        labels: torch.Tensor,
                        obj_ids: list,
                        temperature: float = 0.1,
                        eps: float = 1e-8) -> torch.Tensor:
    """
    NT-Xent (Normalized Temperature-Scaled Cross-Entropy) contrastive loss.

    Positive pairs: same class AND different object (never the sample itself).
    Denominator: all other samples in the batch (self excluded).

    For every anchor i that has at least one positive, the per-sample loss is

        -log((sum_{p in P(i)} exp(s_ip) + eps) / (sum_{k != i} exp(s_ik) + eps))

    where s_ik is the cosine similarity of the L2-normalised features divided
    by ``temperature``. The result is the mean over anchors with at least one
    positive. Both sums are evaluated in log space (log-sum-exp), so small
    temperatures or half-precision inputs cannot overflow ``exp``.

    Args:
        feats: Feature embeddings, shape (B, D) (e.g. (B, 256) PointNet
            global features).
        labels: Class labels with B elements (flattened to (B,)).
        obj_ids: Object identifier per sample, length B. Either a sequence of
            hashable values (ints, strings, ...) or a tensor.
        temperature: Temperature scaling factor for similarity.
        eps: Small constant added to both sums before the log to avoid log(0).
    Returns:
        Scalar contrastive loss tensor. 0 when B <= 1 or when no sample has
        a positive pair.
    Raises:
        ValueError: if ``labels`` or ``obj_ids`` do not contain B elements.
    """
    device = feats.device
    batch_size = feats.shape[0]

    # Return 0 loss if batch size is 1 (no positive/negative pairs)
    if batch_size <= 1:
        return torch.tensor(0., device=device)

    labels = labels.reshape(-1).to(device)
    obj_codes = _encode_obj_ids(obj_ids, device)
    if labels.numel() != batch_size or obj_codes.numel() != batch_size:
        raise ValueError(
            f"expected {batch_size} labels and obj_ids, got "
            f"{labels.numel()} and {obj_codes.numel()}")

    # Cosine similarity of L2-normalised features, scaled by temperature.
    # Half-precision logits are promoted to float32 so the log-sum-exp below
    # is computed accurately; float32/float64 inputs keep their dtype.
    feats_norm = F.normalize(feats, p=2, dim=1)
    sim_matrix = torch.matmul(feats_norm, feats_norm.t())
    sim_matrix = sim_matrix.to(torch.promote_types(sim_matrix.dtype, torch.float32))
    sim_matrix = sim_matrix / temperature

    # Positive mask: same class AND different object, diagonal excluded.
    self_mask = torch.eye(batch_size, dtype=torch.bool, device=device)
    pos_mask = ((labels.unsqueeze(0) == labels.unsqueeze(1))
                & (obj_codes.unsqueeze(0) != obj_codes.unsqueeze(1))
                & ~self_mask)

    # Keep only anchors with at least one positive. Rows are selected *before*
    # the log-sum-exp: an all -inf row would yield NaN in the backward pass
    # even if it were discarded afterwards.
    valid = pos_mask.any(dim=1)
    if not bool(valid.any()):
        return torch.tensor(0., device=device)
    sim_valid = sim_matrix[valid]
    neg_inf = float("-inf")
    log_denominator = torch.logsumexp(
        sim_valid.masked_fill(self_mask[valid], neg_inf), dim=1)
    log_numerator = torch.logsumexp(
        sim_valid.masked_fill(~pos_mask[valid], neg_inf), dim=1)

    # log(x + eps) == logaddexp(log x, log eps). This reproduces the
    # "+ eps" of the linear-space formula without leaving log space;
    # eps == 0 gives log eps == -inf, which leaves the term unchanged.
    log_eps = torch.tensor(eps, dtype=sim_valid.dtype, device=device).log()
    sample_losses = (torch.logaddexp(log_denominator, log_eps)
                     - torch.logaddexp(log_numerator, log_eps))
    return sample_losses.mean()


def _encode_obj_ids(obj_ids, device: torch.device) -> torch.Tensor:
    """Map per-sample object identifiers to a 1-D integer tensor on ``device``.

    Equal identifiers map to equal codes, so ``codes[i] != codes[j]`` agrees
    with ``obj_ids[i] != obj_ids[j]`` for hashable ids. 0-d tensor elements
    are unwrapped with ``.item()`` because tensors hash by identity, not value.
    """
    if isinstance(obj_ids, torch.Tensor):
        return obj_ids.reshape(-1).to(device)
    codes = {}
    encoded = [codes.setdefault(o.item() if isinstance(o, torch.Tensor) else o, len(codes))
               for o in obj_ids]
    return torch.tensor(encoded, dtype=torch.long, device=device)


# Module-level cross-entropy criterion for the classification head.
# Per-sample losses are averaged over the batch; this is PyTorch's default
# reduction, written out explicitly for readability.
classification_loss = nn.CrossEntropyLoss(reduction="mean")