# encoding: utf-8
"""
@author:  l1aoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
import torch.nn.functional as F

from fastreid.utils.events import get_event_storage



def cross_entropy_loss(pred_class_logits, gt_classes, eps, alpha=0.2, eta=-1, q=None, truncate=2.,
                       iter_=0, eta_iter=int(1e11), q_start_iter=184800):
    """Label-smoothed cross-entropy with optional loss-based per-sample re-weighting.

    Args:
        pred_class_logits: logits of shape (batch, num_classes); num_classes must be > 1
            because the smoothing mass is split over the (num_classes - 1) wrong classes.
        gt_classes: integer ground-truth class index per sample, shape (batch,).
        eps: label-smoothing factor. If ``eps >= 0`` it is used as-is; otherwise adaptive
            smoothing is used, where each sample's factor is ``alpha`` times the predicted
            probability of its ground-truth class.
        alpha: scale for adaptive label smoothing (only used when ``eps < 0``).
        eta: temperature of the per-sample weights ``exp(loss / eta)``. Weights are only
            computed (and returned) when ``eta > 0``.
        q: optional tensor used as the normalizer of the weights (``weights / q.mean()``)
            once ``q.shape[0] // batch < iter_ - q_start_iter``. Presumably a buffer of
            previously stored weights -- confirm against the caller.
        truncate: if > 0, weights are clamped to at most this value before normalization.
        iter_: current training iteration.
        eta_iter: the weights are applied to the loss only when ``iter_ >= eta_iter``.
        q_start_iter: iteration offset used in the ``q`` normalization test above. The
            default reproduces the previously hard-coded constant.

    Returns:
        If ``eta <= 0``: the scalar loss, i.e. the sum of per-sample losses divided by the
        number of non-zero per-sample losses (at least 1).
        If ``eta > 0``: a tuple ``(loss, weights, store_weight)``. Before ``eta_iter``,
        ``weights`` are the raw ``exp(loss / eta)`` values and are not applied to the loss.
        From ``eta_iter`` on, ``store_weight`` holds the (truncated) weights before
        normalization, and ``weights`` the normalized weights that multiplied the loss.
    """
    num_classes = pred_class_logits.size(1)
    log_probs = F.log_softmax(pred_class_logits, dim=1)

    # The soft targets are constants for the backward pass, so the smoothing factor is
    # computed without building an autograd graph.
    with torch.no_grad():
        if eps >= 0:
            smooth_param = eps
        else:
            # Adaptive label smoothing: alpha * p(gt class), shape (batch, 1).
            # gather keeps the index on the same device as the probabilities.
            soft_label = F.softmax(pred_class_logits, dim=1)
            smooth_param = alpha * soft_label.gather(1, gt_classes.unsqueeze(1))

        targets = torch.ones_like(log_probs)
        targets *= smooth_param / (num_classes - 1)
        targets.scatter_(1, gt_classes.unsqueeze(1), 1 - smooth_param)

    loss = (-targets * log_probs).sum(dim=1)

    if eta > 0:
        # Per-sample weights grow with the sample's loss; detached so they act as constants.
        weights = (loss.detach() / eta).exp()
        if iter_ >= eta_iter:
            if truncate > 0:
                weights.clamp_(max=truncate)
            store_weight = weights.clone()
            use_q = (q is not None
                     and q.shape[0] // pred_class_logits.shape[0] < iter_ - q_start_iter)
            weights /= q.mean() if use_q else weights.mean()
            loss = loss * weights
        else:
            # Weighting not active yet: report the raw weights without applying them.
            store_weight = weights.clone()

    with torch.no_grad():
        non_zero_cnt = max(loss.nonzero(as_tuple=False).size(0), 1)

    loss = loss.sum() / non_zero_cnt

    if eta > 0:
        return loss, weights, store_weight
    return loss
