import torch
import torch.nn as nn
import torch.nn.functional as F
    
def entropy(x, q=1, batch_wise=False):
    """
    Compute the Shannon entropy of probability vectors, with an optional scale q.

        entropy = - sum p * log(p / q)

    Entries are clamped below at 1e-5 before the log. Exact zeros therefore
    contribute a tiny 1e-5 * log(1e-5 / q) term instead of producing NaN.

    Args:
        x: probability tensor. It is either 1-D [num_classes] (a single
            distribution) or 2-D [b, num_classes] (one distribution per row).
        q: scalar divisor inside the log; q=1 gives the plain entropy.
        batch_wise: for 2-D input, return the per-row entropies of shape [b]
            instead of their mean. It is ignored for 1-D input.

    Returns:
        A scalar tensor for 1-D input, or for 2-D input with batch_wise=False.
        A tensor of shape [b] for 2-D input with batch_wise=True.
        With q=1 and valid distributions, the value peaks at log(num_classes)
        for the uniform distribution, up to the clamping epsilon.

    Raises:
        ValueError: if x is neither 1-D nor 2-D.
    """
    clamped = x.clamp(min=1e-5)
    terms = clamped * (clamped / q).log()
    ndim = terms.dim()

    if ndim == 1:  # a single distribution
        return -terms.sum()
    if ndim != 2:
        raise ValueError('Input tensor is %d-Dimensional' % (ndim))

    per_row = -terms.sum(dim=1)  # one entropy per sample
    return per_row if batch_wise else per_row.mean()

def PEB_loss(prob, peb_kl_w, kldu_w, batch_wise=False):
    '''
    Sparse-Diverse Regularization Loss, SDR loss: entropy plus a weighted KL
    divergence from the uniform distribution over classes.

        loss = H(p) + peb_kl_w * KL_2(p || U(num_classes))

    The entropy term comes from `entropy` and uses the natural log. The KL
    term uses log base 2.

    Args:
        prob: probabilities, either [b, num_classes] (one distribution per row)
            or [num_classes] (a single distribution). Presumably non-negative
            and summing to 1 along the last dim; TODO confirm at call sites.
        peb_kl_w: weight applied to the KL-to-uniform term.
        kldu_w: accepted for backward compatibility. It does not affect the
            returned value.
        batch_wise: for 2-D input, return per-row losses [b] instead of the
            batch mean.

    Returns:
        A scalar tensor, or a tensor of shape [b] when batch_wise=True with
        2-D input.
    '''
    EPS = 1e-5

    def kl_divergence_uniform(probabilities):
        """KL(p || uniform over the last dim) in bits, per row or per distribution."""
        # The reference distribution is uniform over *classes* (last dim), not
        # over the batch.
        num_classes = probabilities.size(-1)
        # Clamp the same way `entropy` does, so exact zeros give a finite value
        # instead of 0 * log2(0) = NaN.
        p = torch.clamp(probabilities, min=EPS)
        # log2(p / (1 / C)) == log2(p * C)
        kldu = torch.sum(p * torch.log2(p * num_classes), dim=-1)
        if batch_wise:
            return kldu
        return kldu.mean()

    en = entropy(prob, batch_wise=batch_wise)
    return en + peb_kl_w * kl_divergence_uniform(prob)

def mutual_information_loss(inputs, labels, num_classes, device):
    """
    Compute the KL divergence from softmax(inputs) to a softened label target.

    Despite the name, this does NOT compute the mutual-information objective
    H(E[p(y|x)]) - E[H(p(y|x))]. It computes

        KL(target || softmax(inputs)), averaged over the batch ("batchmean")

    where target = softmax(one_hot(labels)). Each target row therefore puts
    e / (e + num_classes - 1) on the labelled class and
    1 / (e + num_classes - 1) on every other class.

    Args:
        inputs: unnormalized logits, [b, num_classes].
        labels: int64 class indices with b elements.
        num_classes: number of classes.
        device: device on which the target tensor is built.

    Returns:
        A scalar tensor.
    """
    # Build the one-hot target directly on the target device. This avoids a
    # host allocation followed by a copy.
    one_hot_targets = torch.zeros(len(labels), num_classes, device=device)
    one_hot_targets.scatter_(1, labels.view(-1, 1), 1)  # 1 at each label index

    # NOTE(review): softmax over a one-hot vector produces a heavily smoothed
    # target, not the one-hot itself. Confirm this is intended.
    target = F.softmax(one_hot_targets, dim=1)
    log_probs = F.log_softmax(inputs, dim=1)  # kl_div expects log-probabilities
    return F.kl_div(log_probs, target, reduction="batchmean")

def average_patch_entropy(input_tensor, sdr=False, sdr_kl_w=9.5e-4, kldu_w=1e-1, batch_wise=False, args=None):
    '''
    Average the per-patch entropy (or SDR loss) over all samples and patches.

    Args:
        input_tensor: probabilities, assumed to be
            [batch_size, num_patches, num_classes]. The split along dim 1 and
            the squeeze below rely on this layout.
        sdr (Boolean): if True, use the SDR loss (`PEB_loss`) instead of
            plain entropy.
        sdr_kl_w: weight of the KL-to-uniform term, passed to `PEB_loss`.
        kldu_w: forwarded to `PEB_loss`.
        batch_wise: forwarded to the per-patch loss. The returned average is
            the same for either value.
        args: unused; kept for call-site compatibility.

    Returns:
        A scalar tensor: the mean loss over all batch_size * num_patches
        (sample, patch) pairs.
    '''
    # Each patch has shape [batch_size, 1, num_classes].
    patches = torch.split(input_tensor, 1, dim=1)
    if sdr:
        per_patch = [PEB_loss(patch.squeeze(dim=1), peb_kl_w=sdr_kl_w, kldu_w=kldu_w, batch_wise=batch_wise)
                     for patch in patches]
    else:
        per_patch = [entropy(patch.squeeze(dim=1), batch_wise=batch_wise) for patch in patches]

    # With batch_wise=True the stack is [num_patches, batch_size] of per-sample
    # values. Otherwise it is [num_patches] of per-patch batch means. In both
    # cases .mean() is the average over every (sample, patch) pair. Dividing
    # the sum by batch_size * num_patches would divide the batch_wise=False
    # result by batch_size a second time.
    return torch.stack(per_patch).mean()
