import torch


def compute_ce(logits, labels):
    """Return the unreduced cross-entropy between ``logits`` and ``labels``.

    Delegates to ``torch.nn.functional.cross_entropy`` with
    ``reduction='none'``, so one loss value is produced per target element
    instead of a mean or sum.
    """
    functional = torch.nn.functional
    per_element_loss = functional.cross_entropy(
        logits,
        labels,
        reduction='none',
    )
    return per_element_loss


def compute_gce(logits, labels, q):
    """Return a GCE-style re-weighted cross-entropy, one value per element.

    The per-element cross-entropy from ``compute_ce`` is multiplied by
    ``p_y ** q``, where ``p_y`` is the softmax probability (over dim 1) that
    ``logits`` assign to the label class. The weight is computed under
    ``torch.no_grad()``, so it rescales the cross-entropy gradient but is not
    itself differentiated.

    Args:
        logits: unnormalized scores with classes along dim 1, e.g. ``(N, C)``
            or ``(N, C, d1, ...)`` as accepted by ``cross_entropy``.
        labels: integer class indices with ``logits``' shape minus dim 1.
        q: exponent applied to the target-class probability.

    Returns:
        Tensor of weighted per-element losses, shaped like ``labels``.
    """
    with torch.no_grad():
        prob = torch.nn.functional.softmax(logits, dim=1)
        # Squeeze only the gathered class dim. A bare ``squeeze()`` would also
        # drop size-1 batch/spatial dims, e.g. turning (N, 1, 1) into (N,)
        # while ``ce_loss`` is (N, 1), so the product would broadcast to (N, N).
        target_prob = torch.gather(prob, 1, labels.unsqueeze(1)).squeeze(1)
        # Weight follows the written description (p_y ** q). An alternative
        # noted from a reference implementation additionally multiplies by q,
        # i.e. ``(target_prob ** q) * q``.
        loss_weight = target_prob ** q
    ce_loss = compute_ce(logits, labels)
    return loss_weight * ce_loss
