import torch
import torch.nn.functional as F
import torch.nn as nn

def gaussian_smooth_label(true_class, num_classes, sigma):
    """
    Discrete Gaussian distribution over ``num_classes`` classes centred on ``true_class``.

    Entry k is proportional to exp(-(k - true_class)**2 / (2 * sigma**2)) and the
    vector sums to 1. Only ``sigma**2`` is used, so the sign of ``sigma`` is
    irrelevant. When the variance is zero (``sigma == 0``, or small enough that
    ``2 * sigma**2`` underflows), the limiting distribution is returned instead of
    NaN. That limit is a one-hot vector on the class nearest ``true_class``, split
    evenly on ties.

    Args:
        true_class: centre of the Gaussian, in class-index units (need not be an integer).
        num_classes: length of the returned vector.
        sigma: standard deviation in class-index units.

    Returns:
        1-D tensor of length ``num_classes`` in the default floating dtype.
    """
    classes = torch.arange(num_classes, dtype=torch.get_default_dtype())
    sq_dist = (classes - true_class) ** 2
    two_var = 2 * sigma ** 2
    if two_var == 0:
        # Zero-width Gaussian: all mass on the nearest class(es).
        nearest = (sq_dist == sq_dist.min()).to(sq_dist.dtype)
        return nearest / nearest.sum()
    # softmax(z) == exp(z) / exp(z).sum(), but it subtracts max(z) first, so the
    # result cannot underflow to 0/0 when every class is far from the centre.
    return torch.softmax(-sq_dist / two_var, dim=0)


def _gaussian_label_rows(centers, sigmas, num_classes):
    """Return a (len(centers), num_classes) matrix; row i is a normalised discrete Gaussian.

    Row i is proportional to exp(-(k - centers[i])**2 / (2 * sigmas[i]**2)) over
    class indices k. This is the same Gaussian as ``gaussian_smooth_label``, built
    for the whole batch on ``centers.device``. Rows whose variance is zero become
    one-hot on the nearest class instead of NaN.
    """
    classes = torch.arange(num_classes, dtype=centers.dtype, device=centers.device)
    sq_dist = (classes.unsqueeze(0) - centers.unsqueeze(1)) ** 2
    two_var = (2 * sigmas ** 2).unsqueeze(1)
    degenerate = two_var == 0
    # softmax(z) == exp(z) / exp(z).sum(), but it cannot underflow to 0/0 for narrow Gaussians.
    safe_two_var = torch.where(degenerate, torch.ones_like(two_var), two_var)
    gaussian = torch.softmax(-sq_dist / safe_two_var, dim=1)
    nearest = (sq_dist == sq_dist.min(dim=1, keepdim=True).values).to(sq_dist.dtype)
    nearest = nearest / nearest.sum(dim=1, keepdim=True)
    return torch.where(degenerate, nearest, gaussian)


def gaussian_smooth_kl_loss(batch_y_pred_cls_logits, target_cls, idx, y_hist, weights, sigma_div=3.0, loss_mode='kl'):
    """
    Weighted divergence between predictions and Gaussian-smoothed target labels.

    For sample i the target distribution Q_i is a discrete Gaussian over classes,
    centred on ``target_cls[i]``. Its standard deviation is the distance from the
    target class to the farthest class with a non-zero entry in
    ``y_hist.y_pred_partial_cls[idx[i]]``, divided by ``sigma_div``. If that
    distance is 0, or the entry has no non-zero classes, Q_i is one-hot on the
    target class.

    Args:
        batch_y_pred_cls_logits: (batch, num_classes) unnormalised logits.
        target_cls: per-sample target class; element i is read with ``.item()``.
        idx: per-sample keys into ``y_hist.y_pred_partial_cls``.
        y_hist: object exposing ``y_pred_partial_cls``; each indexed entry is
            assumed to be a 1-D per-class vector (TODO confirm against callers).
        weights: per-sample loss weights with ``weights.shape[0] == batch``.
        sigma_div: divisor turning the candidate spread into sigma.
        loss_mode: ``'js'`` computes the Jensen-Shannon divergence between the
            prediction P and Q. Any other value computes KL(Q || P).

    Returns:
        Scalar tensor: the batch mean of ``weights * divergence``.

    Raises:
        ValueError: if ``weights`` does not have one entry per sample.
    """
    batch_size, num_classes = batch_y_pred_cls_logits.shape
    if weights.shape[0] != batch_size:
        raise ValueError("Weights tensor must have the same length as batch size")
    device = batch_y_pred_cls_logits.device
    weights = weights.to(device)

    centers, sigmas = [], []
    for i in range(batch_size):
        c = target_cls[i].item()
        # nonzero() yields ascending indices, so first/last are the min/max candidates.
        candidates = torch.nonzero(y_hist.y_pred_partial_cls[idx[i]]).flatten()
        if candidates.numel() == 0:
            spread = 0.0  # no candidates recorded: fall back to a one-hot target
        else:
            spread = max(abs(c - candidates[0].item()), abs(c - candidates[-1].item()))
        centers.append(c)
        sigmas.append(spread / sigma_div)

    # Build the targets in float32 (as gaussian_smooth_label does), then match the logits dtype.
    smoothed_labels = _gaussian_label_rows(
        torch.tensor(centers, dtype=torch.float32, device=device),
        torch.tensor(sigmas, dtype=torch.float32, device=device),
        num_classes,
    ).to(batch_y_pred_cls_logits.dtype)

    log_probs = F.log_softmax(batch_y_pred_cls_logits, dim=1)

    if loss_mode == 'js':
        # JS(P, Q) = 0.5 * KL(Q || M) + 0.5 * KL(P || M), with M = (P + Q) / 2.
        # M > 0 wherever P or Q is, so zeros in the Gaussian tails are harmless;
        # the clamp only guards entries where both underflow to 0.
        probs = log_probs.exp()
        mixture = 0.5 * (probs + smoothed_labels)
        log_mixture = mixture.clamp_min(torch.finfo(mixture.dtype).tiny).log()
        point_losses = 0.5 * (
            F.kl_div(log_mixture, smoothed_labels, reduction='none').sum(dim=1)
            + F.kl_div(log_mixture, probs, reduction='none').sum(dim=1)
        )
    else:
        # F.kl_div(log P, Q) summed over classes is KL(Q || P).
        point_losses = F.kl_div(log_probs, smoothed_labels, reduction='none').sum(dim=1)

    return (point_losses * weights).mean()



def partial_label_loss(raw_outputs, partial_Y, mask=None, smooth=-1):
    """
    Cross-entropy of the softmax output against each sample's candidate labels.

    Classes with ``partial_Y > 0`` form the candidate set of a row. The per-row
    loss is the indicator-weighted mean of ``-log(softmax(raw_outputs) + 1e-8)``
    over that set. The result is summed over rows and divided by the number of
    selected samples.

    Args:
        raw_outputs: (batch, num_classes) logits.
        partial_Y: (batch, num_classes) tensor; positive entries mark candidate labels.
        mask: optional (batch,) per-sample selection/weights. Row losses are
            multiplied by it and the total is divided by ``mask.sum()``.
            ``None`` selects every sample and divides by the batch size.
        smooth: if > 0, the candidate indicator is mixed with a uniform
            distribution: ``(1 - smooth) * onezero + smooth / num_classes``.

    Returns:
        Scalar tensor. A row with no candidates (and no smoothing) contributes 0
        instead of NaN. An all-zero mask or an empty batch yields 0.
    """
    device = raw_outputs.device
    batch_size, num_classes = raw_outputs.shape[0], raw_outputs.shape[1]

    sm_outputs = torch.softmax(raw_outputs, dim=1)
    onezero = torch.zeros(batch_size, num_classes, device=device)
    onezero[partial_Y > 0] = 1  # selection of positive labels

    if smooth > 0:
        onezero = (1 - smooth) * onezero + smooth / num_classes

    # A row without candidates has zero mass; 0/0 there would make the whole
    # loss NaN even if the row is masked out. Its numerator is already 0, so
    # dividing by 1 makes it contribute nothing.
    row_mass = onezero.sum(dim=-1, keepdim=True)
    row_mass = torch.where(row_mass > 0, row_mass, torch.ones_like(row_mass))

    neg_log_probs = -torch.log(sm_outputs + 1e-8)
    per_class = onezero * neg_log_probs / row_mass

    if mask is None:
        return per_class.sum() / max(batch_size, 1)

    mask = mask.to(device)
    selected_num = mask.sum()
    # All-zero mask: the masked sum is 0, so divide by 1 rather than produce NaN.
    selected_num = torch.where(selected_num > 0, selected_num, torch.ones_like(selected_num))
    return (per_class * mask.unsqueeze(1)).sum() / selected_num