import torch
from model_transformer.utils import one_hot_encode
from typing import Callable, Optional, Union, Tuple

def softmax(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """
    Normalise a tensor with the softmax function along one dimension.

    Thin wrapper around ``torch.nn.functional.softmax`` so the metrics in
    this module share a single conversion point from logits to probabilities.

    Args:
        x: Input tensor (typically logits)
        dim: Dimension along which softmax is computed (defaults to the last)

    Returns:
        Tensor of the same shape as ``x`` with softmax applied along ``dim``
    """
    functional = torch.nn.functional
    normalised = functional.softmax(x, dim=dim)
    return normalised

class CWCE:
    """
    Class-Wise Calibration Error (CWCE) metric.

    CWCE is an estimator of the marginal calibration error applying binning to estimate
    the observed probabilities corresponding to a confidence range. For every class the
    predicted probabilities are sorted into ``num_bins`` equal-width bins over [0, 1];
    the per-class error is the bin-frequency-weighted ``|accuracy - confidence|**p``
    and the aggregate is the class-weighted average of the per-class errors.

    Warning:
        - CWCE is dependent on sample size and may be unreliable for small datasets
        - CWCE under any binning scheme is biased and underestimates the true CE
        - CWCE only measures calibration quality, not discrimination
        - No systematic studies exist for behavior with imbalanced data

    Args:
        num_bins (int): Number of bins for confidence binning. Should be adjusted based on sample size.
        p (int): Norm parameter for the calibration error (typically p=1)
        class_weights (Optional[torch.Tensor]): Optional weights for each class for unequal interest
                                              across classes. Shape should be (num_classes,).
                                              Normalised to sum to one at call time; the stored
                                              attribute itself is never modified.
        from_logits (bool): If True, inputs to __call__ are expected to be logits, not probabilities

    Raises:
        ValueError: If ``num_bins`` is smaller than 1.
    """
    def __init__(self, num_bins: int = 15, p: int = 1, class_weights: Optional[torch.Tensor] = None, from_logits: bool = False):
        if num_bins < 1:
            raise ValueError("Number of bins must be positive")
        self.num_bins = num_bins
        self.p = p
        self.class_weights = class_weights
        self.from_logits = from_logits

    def __call__(self, inputs: torch.Tensor, labels: torch.Tensor, return_per_class: bool = True):
        """
        Compute the Class-Wise Calibration Error.

        Args:
            inputs: Model logits (if from_logits=True) or probabilities of shape (N, C)
                   where N is number of samples and C is number of classes
            labels: True labels of shape (N,) (class indices) or (N, C) if one-hot encoded
            return_per_class: If True, returns both aggregated CWCE and per-class CEs

        Returns:
            float or tuple: If return_per_class is False, returns the aggregated CWCE value.
                          If True, returns (aggregated_cwce, per_class_ces) where
                          per_class_ces is a list with one float per class.

        Raises:
            ValueError: If class weights were supplied and their count differs from C.
        """
        # Input validation and conversion
        if not isinstance(inputs, torch.Tensor):
            inputs = torch.tensor(inputs, dtype=torch.float32)
        if not isinstance(labels, torch.Tensor):
            labels = torch.tensor(labels)

        # Convert logits to probabilities if needed
        probs = softmax(inputs, dim=1) if self.from_logits else inputs
        if not probs.is_floating_point():
            probs = probs.float()
        num_classes = probs.shape[1]
        device = probs.device

        # Convert index labels to one-hot; build the lookup on the inputs' device
        if labels.dim() == 1:
            index = labels.to(device=device, dtype=torch.long)
            labels = torch.eye(num_classes, device=device)[index]
        # Integer one-hot labels cannot be averaged with .mean(), so use the probs dtype
        labels = labels.to(device=device, dtype=probs.dtype)

        # Resolve class weights into a local so repeated calls (possibly with a
        # different number of classes) never see state left by an earlier call
        if self.class_weights is None:
            class_weights = torch.ones(num_classes, dtype=probs.dtype, device=device) / num_classes
        else:
            class_weights = self.class_weights
            if not isinstance(class_weights, torch.Tensor):
                class_weights = torch.tensor(class_weights, dtype=torch.float32)
            if len(class_weights) != num_classes:
                raise ValueError("Number of class weights must match number of classes")
            class_weights = class_weights.to(device=device, dtype=probs.dtype)
            class_weights = class_weights / class_weights.sum()

        bin_boundaries = torch.linspace(0, 1, self.num_bins + 1, device=device)
        last_bin = self.num_bins - 1

        per_class_ces = torch.zeros(num_classes, dtype=probs.dtype, device=device)
        N = len(labels)

        # Calculate per-class calibration errors
        for c in range(num_classes):
            class_probs = probs[:, c]
            class_labels = labels[:, c]

            class_ce = 0.0
            for m in range(self.num_bins):
                lower = bin_boundaries[m]
                upper = bin_boundaries[m + 1]
                # Bins are half-open [lower, upper) except the last one, which is
                # closed so that a probability of exactly 1.0 is counted. Nudging
                # the upper edge by 1e-8 does not work: 1.0 + 1e-8 == 1.0 in float32.
                if m == last_bin:
                    bin_mask = (class_probs >= lower) & (class_probs <= upper)
                else:
                    bin_mask = (class_probs >= lower) & (class_probs < upper)
                bin_samples = bin_mask.sum()

                if bin_samples > 0:
                    # Calculate |Bcm|/N * ||Accuracy - Confidence||p
                    bin_accuracy = class_labels[bin_mask].mean()
                    bin_confidence = class_probs[bin_mask].mean()
                    class_ce = class_ce + (bin_samples / N) * torch.abs(bin_accuracy - bin_confidence) ** self.p

            per_class_ces[c] = class_ce

        # Calculate weighted average across classes
        cwce = (per_class_ces * class_weights).sum()

        # Apply p-th root if p > 1
        if self.p > 1:
            cwce = cwce ** (1 / self.p)
            per_class_ces = per_class_ces ** (1 / self.p)

        if return_per_class:
            return cwce.item(), per_class_ces.tolist()
        return cwce.item()
    
class NLL:
    """
    Negative Log Likelihood (NLL) of the true class under the predicted distribution.

    Probabilities are clamped to ``[eps, 1]`` before taking the logarithm, so the
    per-sample loss is capped at ``-log(eps)``. The reduction over samples is
    delegated to ``torch.nn.NLLLoss``.
    """
    def __init__(self, reduction='mean', eps=1e-7, from_logits=False):
        """
        Initialize NLL Loss Calculator

        Args:
            reduction (str): Specifies the reduction to apply to the output ('none', 'mean', 'sum')
            eps (float): Small constant for numerical stability (lower clamp on probabilities)
            from_logits (bool): If True, inputs are expected to be logits, not probabilities
        """
        self.nll_loss = torch.nn.NLLLoss(reduction=reduction)
        self.eps = eps
        self.from_logits = from_logits

    def __call__(self, inputs, targets):
        """
        Calculate NLL Loss from logits or probabilities

        Args:
            inputs: Tensor of shape (batch_size, num_classes) containing logits or probabilities
            targets: Tensor of shape (batch_size,) containing target class indices, or a
                     one-hot encoding with the same shape as ``inputs`` (reduced to class
                     indices with argmax over dim 1, so soft labels are not supported)

        Returns:
            Negative Log Likelihood Loss as a tensor (scalar unless reduction='none')
        """
        # Convert to tensor if needed
        if not isinstance(inputs, torch.Tensor):
            inputs = torch.tensor(inputs, dtype=torch.float32)
        if not isinstance(targets, torch.Tensor):
            targets = torch.tensor(targets, dtype=torch.long)

        # Accept one-hot targets like the other metrics in this module. Only a
        # target with exactly the inputs' shape is treated as one-hot, so
        # index targets of any valid NLLLoss shape keep their meaning.
        if targets.dim() > 1 and targets.shape == inputs.shape:
            targets = targets.argmax(dim=1)

        # NLLLoss requires int64 class indices located on the same device as the inputs
        targets = targets.to(device=inputs.device, dtype=torch.long)

        # Convert logits to probabilities if needed
        probabilities = softmax(inputs, dim=1) if self.from_logits else inputs

        # Clamp probabilities to prevent log(0)
        probabilities = torch.clamp(probabilities, self.eps, 1.0)

        log_probabilities = torch.log(probabilities)
        return self.nll_loss(log_probabilities, targets)
    
class RBS:
    """
    Root Brier Score (RBS) metric.

    RBS is the square root of the Brier score, i.e. of the mean over samples of the
    squared Euclidean distance between the predicted probability vector and the
    one-hot encoded true label. It represents a robust upper bound of the
    canonical calibration error.

    Value Range: [0, √2], where lower values indicate better calibration.
    (√2 is reached when all probability mass is placed on a wrong class.)

    Warning:
        - It is not clear how tight the upper bound is, especially for models with low accuracy
        - Should be reported together with ECE/ECE_KDE

    Reference:
        Gruber and Buettner, 2022
    """
    def __init__(self, from_logits=False):
        """
        Initialize RBS calculator

        Args:
            from_logits (bool): If True, inputs are expected to be logits, not probabilities
        """
        self.from_logits = from_logits

    def __call__(self, inputs: torch.Tensor, labels: torch.Tensor) -> float:
        """
        Compute the Root Brier Score.

        Args:
            inputs: Model logits (if from_logits=True) or probabilities of shape (N, C)
                   where N is number of samples and C is number of classes
            labels: True labels of shape (N,) (class indices) or (N, C) if one-hot encoded

        Returns:
            float: The RBS value in range [0, √2]
        """
        # Input validation and conversion
        if not isinstance(inputs, torch.Tensor):
            inputs = torch.tensor(inputs, dtype=torch.float32)
        if not isinstance(labels, torch.Tensor):
            labels = torch.tensor(labels)

        # Convert logits to probabilities if needed
        probabilities = softmax(inputs, dim=1) if self.from_logits else inputs
        if not probabilities.is_floating_point():
            probabilities = probabilities.float()
        device = probabilities.device

        # Convert index labels to one-hot; build the lookup on the inputs' device
        if labels.dim() == 1:
            num_classes = probabilities.shape[1]
            index = labels.to(device=device, dtype=torch.long)
            labels = torch.eye(num_classes, device=device)[index]
        labels = labels.to(device=device, dtype=probabilities.dtype)

        # Squared L2 distance per sample: sum over classes, NOT mean. Averaging over
        # classes as well would shrink the score by sqrt(C) and break both the
        # documented [0, √2] range and the upper-bound interpretation.
        per_sample_sq_dist = ((probabilities - labels) ** 2).sum(dim=1)

        # Brier score is the average over samples; RBS is its square root
        brier_score = per_sample_sq_dist.mean()
        rbs = torch.sqrt(brier_score)

        return rbs.item()
    
class SKCE:
    """
    Squared Kernel Calibration Error (SKCE) metric.

    SKCE is a kernel-based calibration metric that measures the calibration error
    without binning, making it more reliable for smaller datasets compared to ECE/CWCE.

    The value is estimated from all unordered pairs of distinct samples (i < j) as
    the average of ``k(p_i, p_j) * <y_i - p_i, y_j - p_j>``, i.e. a matrix-valued
    kernel of the form ``k(p, p') * I``. Because only distinct pairs are used, the
    estimate can be slightly negative for well-calibrated models.

    Args:
        kernel_fn (Optional[Callable]): Kernel function for probability vectors. Called as
                                      ``kernel_fn(P, P)`` with the (N, C) probability matrix
                                      and must return an (N, N) kernel matrix.
                                      Defaults to Gaussian kernel with σ=0.1
        class_weights (Optional[torch.Tensor]): Optional weights for each class. Stored on the
                                              instance but currently not used by ``__call__``.
        from_logits (bool): If True, inputs are expected to be logits, not probabilities
    """
    def __init__(self,
                 kernel_fn: Optional[Callable] = None,
                 class_weights: Optional[torch.Tensor] = None,
                 from_logits: bool = False):
        self.class_weights = class_weights
        self.kernel_fn = kernel_fn or self._default_gaussian_kernel
        self.from_logits = from_logits

    def _default_gaussian_kernel(self, x: torch.Tensor, y: torch.Tensor, sigma: float = 0.1) -> torch.Tensor:
        """Default Gaussian kernel with σ=0.1 on pairwise Euclidean distances between rows of x and y"""
        dist = torch.cdist(x, y, p=2)
        return torch.exp(-dist**2 / (2 * sigma**2))

    def __call__(self, inputs: torch.Tensor, labels: torch.Tensor) -> float:
        """
        Compute the Squared Kernel Calibration Error.

        Args:
            inputs: Model logits (if from_logits=True) or probabilities of shape (N, C)
                   where N is number of samples and C is number of classes
            labels: True labels of shape (N,) (class indices) or (N, C) if one-hot encoded

        Returns:
            float: The SKCE estimate. Values near 0 indicate good calibration; the
                   pairwise estimate is not restricted to be non-negative.

        Raises:
            ValueError: If fewer than two samples are given (no sample pair exists).
        """
        # Input validation and conversion
        if not isinstance(inputs, torch.Tensor):
            inputs = torch.tensor(inputs, dtype=torch.float32)
        if not isinstance(labels, torch.Tensor):
            labels = torch.tensor(labels)

        # Convert logits to probabilities if needed
        probabilities = softmax(inputs, dim=1) if self.from_logits else inputs
        device = probabilities.device

        # Convert index labels to one-hot; build the lookup on the inputs' device
        if labels.dim() == 1:
            num_classes = probabilities.shape[1]
            index = labels.to(device=device, dtype=torch.long)
            labels = torch.eye(num_classes, device=device)[index]
        labels = labels.to(device=device)

        n = len(labels)
        if n < 2:
            # Previously this surfaced as a bare ZeroDivisionError from the normalisation
            raise ValueError("SKCE requires at least two samples")

        # Compute kernel matrix
        kernel_matrix = self.kernel_fn(probabilities, probabilities)

        # Compute difference between outcomes and predictions
        diff = labels - probabilities

        # d_i^T (k_ij * I) d_j == k_ij * <d_i, d_j>, so all pair contributions are
        # the element-wise product of the kernel matrix with the Gram matrix of diff.
        gram = diff @ diff.t()

        # Keep only pairs with i < j (strict upper triangle); the kernel is not
        # assumed to be symmetric, so the lower triangle is not folded in.
        pair_terms = torch.triu(kernel_matrix * gram, diagonal=1)

        # Normalize by number of pairs
        num_pairs = (n * (n - 1)) // 2
        skce = pair_terms.sum() / num_pairs

        return skce.item()
    
