import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn.modules.loss import _Loss
import math

class RFCELoss(nn.Module):
    """Curvature-adjusted, inverse-frequency-weighted BCE-with-logits loss.

    Called "Ricci Flow Adjusted Cross-Entropy" (RFACE) by its author. For
    every class ``c`` a curvature value is precomputed as::

        kappa_c = log(freq_c / max(freq) + epsilon)

    so ``kappa_c`` is about 0 for the most frequent class and becomes more
    negative the rarer the class is. In ``forward`` the logits are shifted by
    ``-alpha * kappa * (sigmoid(z) - y)`` before a per-class weighted
    ``binary_cross_entropy_with_logits`` is taken, with class weights
    ``sum(freq) / freq``. For ``alpha >= 0`` the factor ``-alpha * kappa`` is
    (approximately) non-negative, so with 0/1 targets the shift moves each
    logit away from its target, and by more for rarer classes.
    """

    def __init__(self, class_freq, alpha=0.3, epsilon=1e-8):
        """
        Args:
            class_freq (torch.Tensor | sequence of numbers): Per-class
                frequencies, shape [C]. Non-tensor inputs are converted with
                ``torch.as_tensor`` (tensors are used as-is, without a copy).
            alpha (float): Curvature adjustment strength.
            epsilon (float): Small value for numerical stability; added inside
                the log of ``kappa`` and used as the lower bound of the
                frequency in the class-weight denominator.
        """
        super().__init__()
        self.alpha = alpha
        self.epsilon = epsilon

        class_freq = torch.as_tensor(class_freq)
        # Persistent buffer: saved in state_dict and moved by Module.to().
        self.register_buffer('class_freq', class_freq)
        # Backward-compatible alias. forward() reads the buffer instead,
        # because this plain attribute is not moved by Module.to().
        self.freq = class_freq
        # Precompute curvature values.
        self.max_freq = torch.max(class_freq)
        # Non-persistent: follows Module.to() without adding a new state_dict
        # key, so previously saved checkpoints still load with strict=True.
        self.register_buffer(
            'kappa',
            torch.log(class_freq / self.max_freq + epsilon),
            persistent=False,
        )

    def forward(self, logits, targets, reduction='mean'):
        """
        Compute the RFACE loss.

        Args:
            logits (torch.Tensor): Predicted logits, shape [N, C].
            targets (torch.Tensor): Either a target matrix with the same shape
                as ``logits`` (multi-label 0/1 or soft targets), or integer
                class indices of shape [N], which are one-hot encoded over
                ``C = logits.size(-1)`` classes.
            reduction (str): Loss reduction method ('none', 'mean', 'sum').

        Returns:
            torch.Tensor: Scalar loss for 'mean' / 'sum'; the per-element loss
            (same shape as ``logits``) for 'none'.

        Raises:
            ValueError: If ``reduction`` is not one of the supported modes.
        """
        # Fail fast before doing any tensor work.
        if reduction not in ('none', 'mean', 'sum'):
            raise ValueError(f"Invalid reduction mode: {reduction}")

        device = logits.device

        if targets.shape == logits.shape:
            # Already a dense per-class target matrix (multi-label case).
            targets = targets.to(logits.dtype)
        else:
            # Class indices: one-hot over the real number of classes rather
            # than a hard-coded 2 (identical for the binary [N, 2] case).
            targets = F.one_hot(
                targets.long(), num_classes=logits.size(-1)
            ).to(logits.dtype)

        # Compute probabilities
        p = torch.sigmoid(logits)
        # BCE gradient w.r.t. the logits: dL_BCE/dz = sigmoid(z) - y, [N, C]
        grad_bce = p - targets

        # Curvature-scaled logit shift. It is not detached, so autograd also
        # differentiates through it. `.to` is a no-op once the module lives
        # on `device`, and keeps an un-moved module working.
        delta_z = -self.alpha * self.kappa.to(device) * grad_bce
        adjusted_logits = logits + delta_z

        # Inverse-frequency class weights, computed from the buffer so they
        # always match `class_freq`. clamp_min prevents an infinite weight (and
        # hence an inf/NaN loss) for a class whose frequency is 0.
        freq = self.class_freq.to(device)
        if not freq.is_floating_point():
            freq = freq.to(torch.get_default_dtype())
        weight = freq.sum() / freq.clamp_min(self.epsilon)

        loss = F.binary_cross_entropy_with_logits(
            adjusted_logits,
            targets,
            reduction='none',
            weight=weight,
        )

        if reduction == 'mean':
            return loss.mean()
        if reduction == 'sum':
            return loss.sum()
        return loss
