"""
Class-Balanced Focal Loss for Multi-label Classification

Implementation based on:
- "Focal Loss for Dense Object Detection" (Lin et al., 2017)
- "Class-Balanced Loss Based on Effective Number of Samples" (Cui et al., 2019)

This loss is specifically designed for imbalanced multi-label datasets like ADE20K.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class ClassBalancedFocalLoss(nn.Module):
    """
    Class-Balanced Focal Loss for handling class imbalance in multi-label classification

    Combines:
    1. Focal Loss - focus on hard examples
    2. Class-balanced weighting - re-weight based on effective number of samples

    Every entry of the ``[batch_size, num_classes]`` logit matrix is treated as an
    independent sigmoid (binary) prediction. With ``p = sigmoid(logit)`` and target
    ``t`` the per-entry loss is

        t * (1 - p)^gamma * (-log p) + (1 - t) * p^gamma * (-log(1 - p))

    optionally multiplied by a per-class weight, then averaged over all entries
    and scaled by ``alpha``.
    """

    def __init__(self,
                 alpha: float = 1.0,
                 gamma: float = 2.0,
                 beta: float = 0.9999,
                 class_frequencies: list = None,
                 eps: float = 1e-8):
        """
        Initialize Class-Balanced Focal Loss

        Args:
            alpha: Scaling factor for the loss
            gamma: Focusing parameter (higher gamma = more focus on hard examples)
            beta: Re-weighting factor in [0, 1) (closer to 1 = more re-weighting
                for rare classes); only used when ``class_frequencies`` is given
            class_frequencies: List of class frequencies in training set (for re-weighting)
            eps: Lower bound applied to the probabilities inside the focal factors

        Raises:
            ValueError: if ``class_frequencies`` is given and is empty, or ``beta``
                is outside [0, 1).
        """
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.beta = beta
        self.eps = eps

        # Calculate class weights based on effective number of samples
        if class_frequencies is not None:
            weights = self._calculate_class_weights(class_frequencies, beta)
        else:
            weights = None
        # Non-persistent buffer: follows ``module.to(device)`` / ``.cuda()`` like
        # the rest of the model, but does not add a key to ``state_dict()``.
        self.register_buffer("class_weights", weights, persistent=False)

        if weights is not None:
            print("Class-Balanced Focal Loss initialized with custom class weights")
            print(f"  Alpha: {alpha}, Gamma: {gamma}, Beta: {beta}")
            print(f"  Weight range: {weights.min():.4f} - {weights.max():.4f}")
        else:
            print("Class-Balanced Focal Loss initialized with uniform weights")
            print(f"  Alpha: {alpha}, Gamma: {gamma}")

    def _calculate_class_weights(self, class_frequencies, beta):
        """
        Calculate class weights based on effective number of samples

        Effective Number = (1 - β^n) / (1 - β)
        Weight = 1 / Effective Number   (then normalized to mean 1)

        Frequencies below 1 (e.g. a class that never occurs in the training set)
        are treated as 1. Otherwise ``1 - β^0 == 0`` would give an infinite weight,
        and the mean normalization would turn every weight into NaN.

        Raises:
            ValueError: if ``class_frequencies`` is empty or ``beta`` is not in [0, 1).
        """
        if not 0.0 <= beta < 1.0:
            raise ValueError(f"beta must be in [0, 1), got {beta}")
        counts = np.asarray(class_frequencies, dtype=np.float64)
        if counts.size == 0:
            raise ValueError("class_frequencies must not be empty")
        counts = np.maximum(counts, 1.0)

        effective_num = 1.0 - np.power(beta, counts)
        weights = (1.0 - beta) / effective_num

        # Normalize weights to have mean = 1
        weights = weights / weights.mean()

        return torch.tensor(weights, dtype=torch.float32)

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute Class-Balanced Focal Loss

        Args:
            inputs: Model predictions (logits) [batch_size, num_classes]
            targets: Ground truth labels [batch_size, num_classes]; bool/int
                targets are cast to the logits' dtype

        Returns:
            loss: Class-balanced focal loss value (scalar)
        """
        if not targets.is_floating_point():
            # ``1 - targets`` is not defined for bool tensors.
            targets = targets.to(inputs.dtype)

        # log p and log(1 - p) straight from the logits. Going through
        # sigmoid -> clamp -> log is not safe: in float32 ``1 - 1e-8`` rounds to
        # 1.0, so the clamp does not prevent log(0) = -inf, and a confident
        # positive (logit >~ 17) then yields 0 * inf = NaN in the negative term.
        log_p = F.logsigmoid(inputs)
        log_one_minus_p = F.logsigmoid(-inputs)

        # Focal modulating factors:
        # For positive samples: (1 - p)^gamma
        # For negative samples: p^gamma
        # Bounded below by eps, as the original probability clamp did.
        p = torch.sigmoid(inputs).clamp(min=self.eps)
        one_minus_p = torch.sigmoid(-inputs).clamp(min=self.eps)
        focal_weights_pos = torch.pow(one_minus_p, self.gamma)
        focal_weights_neg = torch.pow(p, self.gamma)

        # Apply to positive and negative samples
        pos_loss = targets * focal_weights_pos * (-log_p)
        neg_loss = (1 - targets) * focal_weights_neg * (-log_one_minus_p)

        # Combine positive and negative losses
        loss = pos_loss + neg_loss

        # Apply class weights if available
        if self.class_weights is not None:
            if self.class_weights.device != loss.device:
                # Cache on the loss's device; only needed if the module itself
                # was not moved with ``.to(device)``.
                self.class_weights = self.class_weights.to(loss.device)

            # Weight each class (broadcast over the batch dimension)
            loss = loss * self.class_weights.unsqueeze(0)

        # Scale by alpha and take mean
        return self.alpha * loss.mean()

    def update_class_frequencies(self, class_frequencies):
        """Update class frequencies (useful for adaptive re-weighting)

        The new weights are placed on the same device as the current weights,
        if any. Otherwise they are created on the CPU and moved lazily in
        ``forward``.
        """
        new_weights = self._calculate_class_weights(class_frequencies, self.beta)
        if self.class_weights is not None:
            new_weights = new_weights.to(self.class_weights.device)
        self.class_weights = new_weights

    def __repr__(self):
        return f"ClassBalancedFocalLoss(α={self.alpha}, γ={self.gamma}, β={self.beta})"


def calculate_ade20k_class_frequencies(dataset, num_classes: int = 150):
    """
    Calculate class frequencies for ADE20K dataset (or any dataset with the same layout)

    Sums ``dataset[i]['labels']`` over all samples. Labels are presumably
    multi-hot vectors (1 = class present); TODO confirm against the dataset
    implementation.

    Args:
        dataset: Indexable dataset supporting ``len()``, whose samples are dicts
            with a ``'labels'`` tensor or array holding ``num_classes`` entries
        num_classes: Number of classes (150 for ADE20K)

    Returns:
        list: Summed label value for each of the ``num_classes`` classes

    Raises:
        ValueError: if ``num_classes`` < 1, or a sample's labels do not contain
            exactly ``num_classes`` entries (instead of silently broadcasting).
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    print("Calculating class frequencies for ADE20K dataset...")

    class_counts = np.zeros(num_classes)

    for i in range(len(dataset)):
        labels = dataset[i]['labels']
        if isinstance(labels, torch.Tensor):
            # .numpy() alone fails for CUDA tensors or tensors requiring grad.
            labels = labels.detach().cpu().numpy()
        labels = np.asarray(labels, dtype=np.float64).reshape(-1)
        if labels.size != num_classes:
            raise ValueError(
                f"Sample {i}: expected {num_classes} label entries, got {labels.size}"
            )
        class_counts += labels

    print(f"Class frequency statistics:")
    print(f"  Min frequency: {class_counts.min():.0f}")
    print(f"  Max frequency: {class_counts.max():.0f}")
    print(f"  Mean frequency: {class_counts.mean():.1f}")
    print(f"  Std frequency: {class_counts.std():.1f}")

    return class_counts.tolist()


# Convenience function for creating loss with ADE20K frequencies
def create_ade20k_class_balanced_focal_loss(train_dataset,
                                           alpha=1.0,
                                           gamma=2.0,
                                           beta=0.9999):
    """
    Build a ClassBalancedFocalLoss whose per-class weights come from ADE20K counts

    The class frequencies are first counted over ``train_dataset`` with
    ``calculate_ade20k_class_frequencies``. They are then passed, together with
    the focal/re-weighting hyper-parameters, to the loss constructor.

    Args:
        train_dataset: ADE20K training dataset
        alpha: Scaling factor
        gamma: Focusing parameter
        beta: Re-weighting factor

    Returns:
        ClassBalancedFocalLoss instance
    """
    frequencies = calculate_ade20k_class_frequencies(train_dataset)
    loss_kwargs = dict(alpha=alpha, gamma=gamma, beta=beta)
    return ClassBalancedFocalLoss(class_frequencies=frequencies, **loss_kwargs)


# Test function
if __name__ == "__main__":
    # Smoke test with dummy data: checks that the loss and its gradients are
    # finite for both the uniform and the class-weighted configurations.
    print("Testing Class-Balanced Focal Loss...")
    torch.manual_seed(0)  # reproducible dummy data

    # Create loss function
    loss_fn = ClassBalancedFocalLoss(alpha=1.0, gamma=2.0, beta=0.9999)

    # Dummy data
    batch_size, num_classes = 4, 10
    inputs = torch.randn(batch_size, num_classes, requires_grad=True)
    targets = torch.randint(0, 2, (batch_size, num_classes)).float()

    # Forward pass
    loss = loss_fn(inputs, targets)
    print(f"Loss value: {loss.item():.4f}")
    assert torch.isfinite(loss), "loss is not finite"

    # Backward pass
    loss.backward()
    assert inputs.grad is not None and torch.isfinite(inputs.grad).all(), \
        "gradients are missing or not finite"
    print("Gradients computed successfully!")

    # Class-balanced path: random (non-zero) per-class frequencies
    frequencies = torch.randint(1, 1000, (num_classes,)).tolist()
    weighted_fn = ClassBalancedFocalLoss(alpha=1.0, gamma=2.0, beta=0.9999,
                                         class_frequencies=frequencies)
    weighted_loss = weighted_fn(inputs.detach(), targets)
    print(f"Weighted loss value: {weighted_loss.item():.4f}")
    assert torch.isfinite(weighted_loss), "weighted loss is not finite"

    print("✅ Class-Balanced Focal Loss test passed!")