"""
Cross-modal alignment loss functions.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple


class CLIPLoss(nn.Module):
    """
    CLIP-style contrastive learning loss for cross-modal alignment.

    Row ``i`` of ``features1`` and row ``i`` of ``features2`` form a positive
    pair; every other row of ``features2`` in the batch acts as a negative for
    row ``i`` of ``features1``.

    Args:
        temperature: Divisor applied to the similarity logits.
        normalize: If True, L2-normalize both feature sets before comparing.
        symmetric: If True, average the features1->features2 and
            features2->features1 cross-entropy terms, as in the original CLIP
            objective. Defaults to False, which keeps the one-directional
            (features1 -> features2) loss this class has always computed.
    """

    def __init__(self, temperature: float = 0.07, normalize: bool = True,
                 symmetric: bool = False):
        super().__init__()
        self.temperature = temperature
        self.normalize = normalize
        self.symmetric = symmetric

    def forward(self, features1: torch.Tensor, features2: torch.Tensor) -> torch.Tensor:
        """
        Compute CLIP loss between two sets of features.

        Args:
            features1: [batch_size, feature_dim]
            features2: [batch_size, feature_dim]

        Returns:
            CLIP loss scalar. For a batch of one, ``1 - cosine_similarity``
            is returned instead of the cross-entropy term.
        """
        if self.normalize:
            features1 = F.normalize(features1, dim=-1)
            features2 = F.normalize(features2, dim=-1)

        if features1.size(0) == 1:
            # With a single sample the logit matrix is 1x1, so cross-entropy
            # is identically 0 (no negatives). Fall back to a cosine distance
            # so the pair still receives a training signal.
            similarity = F.cosine_similarity(features1, features2, dim=-1)
            return 1.0 - similarity.mean()

        # [B, B] similarity matrix; entry (i, j) compares sample i of
        # features1 with sample j of features2.
        logits = torch.matmul(features1, features2.T) / self.temperature

        # Positive pairs sit on the diagonal.
        labels = torch.arange(features1.size(0), device=features1.device)

        loss = F.cross_entropy(logits, labels)
        if self.symmetric:
            # Transposed logits score features2 -> features1 retrieval.
            loss = 0.5 * (loss + F.cross_entropy(logits.T, labels))
        return loss


class InfoNCELoss(nn.Module):
    """
    InfoNCE loss for contrastive learning.

    Args:
        temperature: Divisor applied to all similarity logits.
        normalize: If True, L2-normalize anchors, positives and (when given)
            negatives before computing dot-product similarities.
    """

    def __init__(self, temperature: float = 0.07, normalize: bool = True):
        super().__init__()
        self.temperature = temperature
        self.normalize = normalize

    def forward(self, features1: torch.Tensor, features2: torch.Tensor,
                negative_features: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Compute InfoNCE loss.

        Args:
            features1: [batch_size, feature_dim] - anchor features
            features2: [batch_size, feature_dim] - positive features
            negative_features: [num_negatives, feature_dim] - negative features
                shared by every anchor (optional). When omitted, the other
                samples of ``features2`` in the batch serve as negatives.

        Returns:
            InfoNCE loss scalar
        """
        if self.normalize:
            features1 = F.normalize(features1, dim=-1)
            features2 = F.normalize(features2, dim=-1)

        batch_size = features1.size(0)

        if negative_features is not None:
            if self.normalize:
                negative_features = F.normalize(negative_features, dim=-1)
            # Row-wise similarity of each anchor with its own positive.
            positive_sim = torch.sum(features1 * features2, dim=-1) / self.temperature
            # [B, num_negatives]: each anchor against every explicit negative.
            negative_sim = torch.matmul(features1, negative_features.T) / self.temperature
            # Column 0 holds the positive, so the target class is 0 for every row.
            logits = torch.cat([positive_sim.unsqueeze(-1), negative_sim], dim=-1)
            labels = torch.zeros(batch_size, device=features1.device, dtype=torch.long)
        else:
            # In-batch negatives: [B, B] matrix whose positive for row i is
            # column i (the diagonal), not column 0.
            logits = torch.matmul(features1, features2.T) / self.temperature
            labels = torch.arange(batch_size, device=features1.device)

        return F.cross_entropy(logits, labels)


class ContrastiveLoss(nn.Module):
    """
    General contrastive loss with configurable margin.

    Positive pairs are penalized by their squared Euclidean distance. Negative
    pairs are penalized by ``relu(margin - distance) ** 2``, i.e. only while
    they are closer than ``margin``.

    Args:
        margin: Distance beyond which negative pairs incur no loss.
        normalize: If True, L2-normalize both feature sets first.
    """

    def __init__(self, margin: float = 1.0, normalize: bool = True):
        super().__init__()
        self.margin = margin
        self.normalize = normalize

    def forward(self, features1: torch.Tensor, features2: torch.Tensor,
                labels: torch.Tensor) -> torch.Tensor:
        """
        Compute contrastive loss.

        Args:
            features1: [batch_size, feature_dim]
            features2: [batch_size, feature_dim]
            labels: [batch_size] - 1 for positive pairs, 0 for negative pairs.
                Float, integer and bool tensors are accepted.

        Returns:
            Contrastive loss scalar
        """
        if self.normalize:
            features1 = F.normalize(features1, dim=-1)
            features2 = F.normalize(features2, dim=-1)

        # Per-pair Euclidean distance, shape [batch_size].
        distances = torch.norm(features1 - features2, dim=-1)

        # Cast so `1 - labels` works for bool masks (PyTorch rejects `-` on
        # bool tensors) and the arithmetic stays in the feature dtype.
        labels = labels.to(dtype=distances.dtype)

        # Positive pairs: minimize distance
        positive_loss = labels * distances.pow(2)

        # Negative pairs: push apart until at least `margin` away
        negative_loss = (1 - labels) * F.relu(self.margin - distances).pow(2)

        return (positive_loss + negative_loss).mean()


class TemporalAlignmentLoss(nn.Module):
    """
    Loss for temporal alignment between modalities.

    Cosine similarities between every pair of time steps are weighted by
    ``exp(-|t1 - t2| / temperature)``, so pairs that are close in time
    dominate. The loss is the negative mean of the weighted similarities.

    Note: ``max_lag`` is stored on the module but is not read by ``forward``.
    """

    def __init__(self, max_lag: int = 10, temperature: float = 0.1):
        super().__init__()
        self.max_lag = max_lag
        self.temperature = temperature

    def forward(self, features1: torch.Tensor, features2: torch.Tensor,
                timestamps1: torch.Tensor, timestamps2: torch.Tensor) -> torch.Tensor:
        """
        Compute temporal alignment loss.

        Args:
            features1: [batch_size, seq_len, feature_dim]
            features2: [batch_size, seq_len, feature_dim]
            timestamps1: [batch_size, seq_len]
            timestamps2: [batch_size, seq_len]

        Returns:
            Temporal alignment loss scalar
        """
        # Unpacking the shape rejects a features1 that is not 3-D (ValueError).
        _batch, _steps, _dim = features1.shape

        # Pairwise absolute time gaps, broadcast to [B, L1, L2].
        gap = (timestamps1[..., :, None] - timestamps2[..., None, :]).abs()
        weights = torch.exp(-gap / self.temperature)

        # Cosine similarity between every step of features1 and features2.
        unit1 = F.normalize(features1, dim=-1)
        unit2 = F.normalize(features2, dim=-1)
        sims = unit1 @ unit2.transpose(-2, -1)

        # Reward similarity where the time weighting is high.
        return -(sims * weights).mean()


class CrossModalConsistencyLoss(nn.Module):
    """
    Loss for ensuring consistency across different modalities.

    Each modality's features are pulled toward the fused representation by
    summing ``-mean(cosine_similarity(modality, fused))`` over all modalities
    and scaling the sum by ``consistency_weight``.
    """

    def __init__(self, consistency_weight: float = 1.0):
        super().__init__()
        self.consistency_weight = consistency_weight

    def forward(self, modality_features: dict, fusion_features: torch.Tensor) -> torch.Tensor:
        """
        Compute cross-modal consistency loss.

        Args:
            modality_features: dict of modality-specific features. Values with
                more than 2 dims are averaged over dim 1 first (presumably a
                sequence/token axis; verify against callers).
            fusion_features: [batch_size, feature_dim] - fused features

        Returns:
            Consistency loss scalar as a 0-d tensor (zero if
            ``modality_features`` is empty).
        """
        # The fused representation does not change per modality; normalize once.
        fusion_norm = F.normalize(fusion_features, dim=-1)

        # Start from a tensor so the return is always a tensor, even for an
        # empty dict (a bare 0.0 would otherwise leak out as a Python float).
        consistency_loss = fusion_features.new_zeros(())

        for features in modality_features.values():
            if features.dim() > 2:
                features = features.mean(dim=1)  # [batch_size, feature_dim]

            features_norm = F.normalize(features, dim=-1)

            # Per-sample cosine similarity with the fused features.
            similarity = torch.sum(features_norm * fusion_norm, dim=-1)

            # Maximize similarity by minimizing its negative.
            consistency_loss = consistency_loss - similarity.mean()

        return self.consistency_weight * consistency_loss