# encoding: utf-8
"""
@author:  l1aoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
import torch.nn.functional as F

from fastreid.utils.events import get_event_storage



def groupdro(pred_class_logits, gt_classes, domain_labels, eps, alpha=0.2, eta=-1, q=None,
             num_domains=5):
    """Label-smoothed cross-entropy re-weighted per domain with Group DRO weights.

    Each sample's (label-smoothed) cross-entropy is multiplied by the weight
    ``q[d]`` of its domain ``d``.  Before weighting, every domain present in the
    batch has its weight multiplied by ``exp(step * mean loss of that domain)``
    and ``q`` is renormalized to sum to 1, so higher-loss domains gain weight.

    Args:
        pred_class_logits: classification logits, shape (batch, num_classes).
        gt_classes: integer class targets, shape (batch,).
        domain_labels: integer domain index per sample, shape (batch,).  Samples
            whose label lies outside ``[0, len(q))`` keep weight 1 and do not
            take part in the ``q`` update.
        eps: label-smoothing amount.  If negative, adaptive smoothing is used:
            ``alpha * p(gt)`` per sample, ``p`` being the softmax of the logits.
        alpha: scale of the adaptive smoothing (only used when ``eps < 0``).
        eta: step size of the multiplicative ``q`` update.  Non-positive values
            (including the default ``-1``) use a step of ``1e-2``.
        q: 1-D tensor of per-domain weights.  It is updated IN PLACE and also
            returned; its length defines the number of domains.  If None, a
            uniform tensor of length ``num_domains`` is created (keep the
            returned tensor and pass it back on the next call).
        num_domains: number of domains, used only when ``q`` is None.

    Returns:
        Tuple ``(loss, q)``: the scalar loss and the updated domain weights.
    """
    num_classes = pred_class_logits.size(1)

    if eps >= 0:
        smooth_param = eps
    else:
        # Adaptive label smooth regularization: the smoothing grows with the
        # predicted probability of the ground-truth class.
        soft_label = F.softmax(pred_class_logits, dim=1)
        smooth_param = alpha * soft_label[torch.arange(soft_label.size(0)), gt_classes].unsqueeze(1)

    log_probs = F.log_softmax(pred_class_logits, dim=1)
    with torch.no_grad():
        # Soft targets: (1 - s) on the true class, s / (C - 1) on every other one.
        targets = torch.ones_like(log_probs)
        targets *= smooth_param / (num_classes - 1)
        targets.scatter_(1, gt_classes.data.unsqueeze(1), (1 - smooth_param))

    loss = (-targets * log_probs).sum(dim=1)

    if q is None:
        q = torch.full((num_domains,), 1.0 / num_domains, device=loss.device)
    num_domains = q.numel()
    step_size = eta if eta > 0 else 1e-2

    with torch.no_grad():
        # Multiplicative (exponentiated-gradient) update of the domain weights;
        # domains absent from the batch keep their current weight.
        for i in range(num_domains):
            group_mask = domain_labels == i
            if not group_mask.any():
                continue
            q[i] *= torch.exp(step_size * loss[group_mask].mean()).to(q)
        q /= q.sum()

    # Per-sample weights q[d]; samples with an out-of-range domain keep weight 1.
    q_on_loss = q.to(loss)
    weights = torch.ones_like(loss)
    for i in range(num_domains):
        weights[domain_labels == i] = q_on_loss[i]
    loss = loss * weights

    with torch.no_grad():
        non_zero_cnt = max(loss.nonzero(as_tuple=False).size(0), 1)

    # q sums to 1 (mean weight 1/num_domains), so scaling by num_domains keeps
    # the result on the scale of an unweighted mean over non-zero samples.
    loss = loss.sum() * num_domains / non_zero_cnt

    return loss, q
