# encoding: utf-8
"""
@author:  l1aoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
import torch.nn.functional as F

from fastreid.utils.events import get_event_storage



def cross_entropy_loss(pred_class_logits, gt_classes, eps, alpha=0.2, eta=-1, q=None, truncate=2., auxiliary_loss=torch.tensor(0)):
    """Label-smoothed cross entropy with optional per-sample re-weighting.

    Args:
        pred_class_logits: logits; dimension 1 is treated as the class axis.
        gt_classes: ground-truth class indices, one per row of the logits.
        eps: label-smoothing amount. When ``eps >= 0`` it is used as is;
            otherwise the amount is computed per sample as ``alpha`` times
            the softmax probability of the ground-truth class.
        alpha: scale of the adaptive smoothing (only read when ``eps < 0``).
        eta: when ``eta > 0`` every per-sample loss is multiplied by a
            gradient-free weight ``exp((loss + auxiliary_loss) / eta)``.
        q: optional tensor whose mean normalises the weights; when ``None``
            the mean of the (truncated) weights themselves is used.
        truncate: when ``> 0``, weights above this value are set to it
            before normalisation.
        auxiliary_loss: value added to the detached per-sample loss before
            the weights are computed.

    Returns:
        The summed per-sample loss divided by the number of non-zero
        per-sample losses (at least 1). When ``eta > 0`` a tuple
        ``(loss, weights, store_weight)`` is returned instead, where
        ``store_weight`` is a copy of the weights taken before truncation
        and normalisation.
    """
    reweight = eta > 0
    class_count = pred_class_logits.size(1)

    if eps >= 0:
        smoothing = eps
    else:
        # Adaptive label smooth regularization: scale the smoothing amount
        # by how confident the model already is in the true class.
        class_probs = F.softmax(pred_class_logits, dim=1)
        sample_idx = torch.arange(class_probs.size(0))
        smoothing = alpha * class_probs[sample_idx, gt_classes].unsqueeze(1)

    log_probs = F.log_softmax(pred_class_logits, dim=1)

    # The soft targets are constants with respect to the gradient.
    with torch.no_grad():
        off_target = smoothing / (class_count - 1)
        targets = torch.ones_like(log_probs).mul_(off_target)
        targets.scatter_(1, gt_classes.data.unsqueeze(1), 1 - smoothing)

    loss = torch.sum(-targets * log_probs, dim=1)

    if reweight:
        weights = torch.exp((loss.data + auxiliary_loss.data) / eta)
        # Snapshot of the raw weights, returned to the caller untouched.
        store_weight = weights.clone()
        if truncate > 0:
            weights.masked_fill_(weights > truncate, truncate)
        normaliser = weights if q is None else q
        weights /= torch.mean(normaliser)
        loss = loss * weights

    # NOTE: a confidence-penalty variant is disabled here; it subtracted
    # 0.3 * entropy(softmax(logits)) from each per-sample loss and clamped
    # the result at 0.

    # Average only over samples whose loss is non-zero; never divide by 0.
    with torch.no_grad():
        active = max(torch.nonzero(loss, as_tuple=False).size(0), 1)

    loss = loss.sum() / active

    if reweight:
        return loss, weights, store_weight
    return loss
