import torch
import torch.nn.functional as F
import math

def SoftBinnedECE(logits, labels, n_bins=15, temperature=0.1, p=2):
    """Compute a differentiable, soft-binned expected calibration error.

    Each sample's top-class confidence is softly assigned to ``n_bins``
    anchors evenly spaced on [0, 1] via a softmax over the negative squared
    distance to each anchor divided by ``temperature``. The per-bin mean
    confidence and mean accuracy are compared, the ``p``-th powers of the gaps
    are weighted by each bin's share of the batch, summed, and raised to the
    ``1 / p`` power.

    Args:
        logits: Tensor of shape (batch, num_classes) with unnormalized scores.
        labels: Integer class indices, one per sample. Any shape holding
            exactly ``batch`` elements is accepted, e.g. (batch,) or
            (batch, 1); it is flattened before comparison.
        n_bins: Number of soft bins (anchors).
        temperature: Softness of the bin assignment; smaller values approach
            a hard assignment to the nearest anchor.
        p: Order of the norm used to aggregate the per-bin gaps.

    Returns:
        A 0-dim tensor with the calibration error. An empty batch yields 0.

    Raises:
        ValueError: If ``labels`` does not contain one entry per sample.
    """
    batch_size = logits.size(0)

    # Flatten so that (batch, 1) labels do not broadcast against the (batch,)
    # predictions into a (batch, batch) comparison matrix.
    labels = labels.reshape(-1)
    if labels.numel() != batch_size:
        raise ValueError(
            f"labels must contain one entry per sample: got {labels.numel()} "
            f"labels for a batch of {batch_size}"
        )

    if batch_size == 0:
        # Without this guard the bin weighting divides by zero and returns NaN.
        return logits.new_zeros(())

    probs = F.softmax(logits, dim=1)
    confidences, predictions = torch.max(probs, 1)
    accuracies = predictions.eq(labels).float()

    bin_centers = torch.linspace(0.0, 1.0, n_bins, device=logits.device)

    # Soft membership of every sample in every bin, shape (batch, n_bins).
    diff = confidences.view(-1, 1) - bin_centers.view(1, -1)
    bin_logits = -(diff ** 2) / temperature
    bin_weights = F.softmax(bin_logits, dim=1)

    # Soft bin mass; the epsilon keeps the per-bin means finite for bins that
    # receive (numerically) no weight.
    S_j = bin_weights.sum(dim=0) + 1e-8
    C_j = (confidences.unsqueeze(1) * bin_weights).sum(dim=0) / S_j
    A_j = (accuracies.unsqueeze(1) * bin_weights).sum(dim=0) / S_j

    # Compute the calibration error.
    errors = (A_j - C_j).abs().pow(p)
    # Weight each bin's gap by its share of the batch (labelled
    # "Equation (11)" by the original author; the source is not cited here).
    weighted_errors = (S_j / batch_size) * errors

    ece = weighted_errors.sum().pow(1 / p)
    return ece



def SoftAvUCLoss(logits, labels, temperature: float = 0.1, theta: float = 0.5, use_deprecated_v0: bool = False) -> torch.Tensor:
    """Compute a soft accuracy-versus-uncertainty calibration (AvUC) loss.

    Predictive entropy is used as the uncertainty measure. Every sample gets a
    soft "uncertain" weight ``qus`` and a soft "certain" weight ``qcs``; these
    are combined with ``tanh(entropy)`` and summed separately over correctly
    and incorrectly classified samples to form four soft counts
    (accurate-certain, accurate-uncertain, inaccurate-certain,
    inaccurate-uncertain). The loss is
    ``log(1 + (AU + IC) / max(AC + IU, 1e-6))``.

    Args:
        logits: Tensor of shape (batch, num_classes) with unnormalized scores.
        labels: Integer class indices, one per sample. Any shape holding
            exactly ``batch`` elements is accepted, e.g. (batch,) or
            (batch, 1); it is flattened before comparison.
        temperature: Temperature of the sigmoid that maps normalized entropy
            to the soft "uncertain" weight (ignored when
            ``use_deprecated_v0`` is true).
        theta: Normalized-entropy value at which the soft "uncertain" weight
            is approximately 0.5 (ignored when ``use_deprecated_v0`` is true).
        use_deprecated_v0: Use the older soft assignment, a two-way softmax
            over negative squared distances of the entropy to its maximum
            and to zero.

    Returns:
        A 0-dim tensor holding the loss.

    Raises:
        ValueError: If ``labels`` does not contain one entry per sample.
    """
    eps = 1e-6

    # Flatten so that (batch, 1) labels do not broadcast against the (batch,)
    # predictions into a (batch, batch) mask.
    labels = labels.reshape(-1)
    if labels.numel() != logits.size(0):
        raise ValueError(
            f"labels must contain one entry per sample: got {labels.numel()} "
            f"labels for a batch of {logits.size(0)}"
        )

    probabilities = F.softmax(logits, dim=1)
    is_accurate_mask = torch.argmax(probabilities, dim=1).eq(labels)
    is_inaccurate_mask = ~is_accurate_mask

    num_classes = probabilities.size(1)

    # Blend with the uniform distribution so log() never sees an exact zero.
    log_safe_probabilities = (1.0 - eps) * probabilities + eps / num_classes
    log_probabilities = torch.log(log_safe_probabilities)
    entropies = -torch.sum(log_safe_probabilities * log_probabilities, dim=1)

    # Maximum attainable entropy: that of the uniform distribution.
    entmax = math.log(num_classes)

    if use_deprecated_v0:
        xus = -((entropies - entmax) ** 2)
        xcs = -(entropies ** 2)
        qucs = F.softmax(torch.stack([xus, xcs], dim=1), dim=1)
        qus = qucs[:, 0]
        qcs = qucs[:, 1]
    else:
        # entmax is 0 for a single class; floor it to avoid 0 / 0. Clamping
        # keeps rounding overshoot from driving the log argument negative.
        normalized_e = (entropies / max(entmax, eps)).clamp(0.0, 1.0)
        odds = normalized_e * (1 - theta) / ((1 - normalized_e) * theta + eps)
        qus = torch.sigmoid((1 / temperature) * torch.log(odds + eps))
        qcs = 1.0 - qus

    tanh_entropies = torch.tanh(entropies)

    # Per-sample contributions shared by the accurate and inaccurate sums.
    certain_values = qcs * (1.0 - tanh_entropies)
    uncertain_values = qus * tanh_entropies

    # NAC: Accurate-Certain
    nac_diff = torch.sum(certain_values[is_accurate_mask])
    # NAU: Accurate-Uncertain
    nau_diff = torch.sum(uncertain_values[is_accurate_mask])
    # NIC: Inaccurate-Certain
    nic_diff = torch.sum(certain_values[is_inaccurate_mask])
    # NIU: Inaccurate-Uncertain
    niu_diff = torch.sum(uncertain_values[is_inaccurate_mask])

    # Floor the denominator so the ratio stays finite when no sample counts
    # as accurate-certain or inaccurate-uncertain.
    denominator = torch.clamp(nac_diff + niu_diff, min=eps)
    avuc_loss = torch.log(1.0 + (nau_diff + nic_diff) / denominator)

    return avuc_loss