import torch
import numpy as np
import torch.nn as nn
from torch.nn import functional as F


class EurekaLoss(nn.Module):
    """Negative log-likelihood plus a class-weighted "encouragement" bonus.

    ``forward`` returns ``(all_loss, org_loss)``. ``org_loss`` is the plain
    mean NLL of ``log_softmax(x)``. ``all_loss`` equals ``org_loss`` during
    evaluation and during deferred epochs. Otherwise it additionally contains
    the mean over non-ignored targets of ``w_y * log(1 - p_y)`` (likelihood
    bonus) or ``-w_y * p_y ** gamma`` (power bonus). Here ``p_y`` is the
    predicted probability of the target class and ``w_y`` its class weight.

    Args:
        opt: config object. The attributes read are:
            ``opt.beta``: balancing factor for the effective-number class
            weights (documented default 0.9999; a higher value gives rare
            classes more encouragement).
            ``opt.defer_start``: the bonus is skipped while the current
            epoch is ``<= defer_start``. A falsy value disables deferral
            (documented default 1).
            ``opt.bonus_gamma``: a value ``> 0`` selects the power bonus.
            Any other value selects the ``log(1 - p)`` bonus (documented
            default -1).
        num_class: number of classes ``C``.
        cls_num_list: per-class sample counts of length ``C``. ``None``
            gives uniform (all-ones) class weights.
        ignore_idx: target value excluded from both loss terms.
        cl_eps: lower clamp on ``1 - p`` so the log bonus stays finite.

    Raises:
        ValueError: if ``cls_num_list`` does not have ``num_class`` entries.
    """

    def __init__(self, opt, num_class, cls_num_list=None, ignore_idx=-100, cl_eps=1e-5):
        super().__init__()
        self.opt = opt
        self.num_class = num_class
        self.ignore_idx = ignore_idx
        if cls_num_list is None:
            # No counts supplied: use uniform weights. This is also what equal
            # counts produce after the normalisation in cal_effective_weight.
            weight = torch.ones(num_class)
        else:
            weight = self.cal_effective_weight(cls_num_list, beta=opt.beta)
        if weight.numel() != num_class:
            raise ValueError(
                f"cls_num_list has {weight.numel()} entries, expected num_class={num_class}"
            )
        self.weight = weight  # shape [num_class]
        self.cl_eps = cl_eps
        # 1-based epoch counter, updated through set_epoch().
        self.epoch = 1

    def cal_effective_weight(self, cls_num_list, beta=0.9999):
        """Return per-class weights from the "effective number" of samples.

        ``weight_c`` is proportional to ``(1 - beta) / (1 - beta ** n_c)``,
        where ``n_c`` is the count of class ``c``. The weights are rescaled
        so that they sum to the number of classes. The result is a float32
        tensor of shape ``[len(cls_num_list)]``.

        The tensor is placed on CUDA when it is available, as the original
        code always did. Otherwise it stays on CPU instead of crashing.
        ``forward`` moves it to the input's device in either case.
        """
        counts = np.asarray(cls_num_list, dtype=np.float64)
        effective_num = 1.0 - np.power(beta, counts)
        per_cls_weights = (1.0 - beta) / effective_num
        per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(counts)
        weight = torch.tensor(per_cls_weights, dtype=torch.float32)
        if torch.cuda.is_available():
            weight = weight.cuda()
        return weight

    def forward(self, x, target, is_train=True):
        """Compute ``(all_loss, org_loss)`` for scores ``x`` and labels ``target``.

        ``x`` holds unnormalised scores laid out as ``F.nll_loss`` expects:
        ``(N, C)``, ``(N, C, d1, ...)`` or an unbatched ``(C,)``. ``target``
        holds the matching class indices.
        """
        # The class dim follows F.nll_loss's convention: 1 for batched input,
        # 0 for unbatched input. Passing it explicitly avoids the deprecated
        # implicit-dim log_softmax, which picks dim 0 for 3-D input.
        class_dim = 1 if x.dim() > 1 else 0
        lprobs = F.log_softmax(x, dim=class_dim)
        mle_loss = F.nll_loss(lprobs, target, reduction='mean', ignore_index=self.ignore_idx)  # -y* log p
        org_loss = mle_loss

        # NOTE(review): the epoch counter starts at 1 and deferral uses `<=`,
        # so the default defer_start=1 skips the bonus for the whole first
        # epoch. The option's help text suggests that encouragement starts at
        # the beginning; confirm which behaviour is intended.
        deferred = bool(self.opt.defer_start) and self.get_epoch() <= self.opt.defer_start
        if not is_train or deferred:
            return mle_loss, org_loss

        probs = torch.exp(lprobs)
        bg = self.opt.bonus_gamma
        if bg > 0:
            bonus = -torch.pow(probs, bg)  # power bonus
        else:
            # Likelihood bonus log(1 - p). The clamp keeps it finite as p -> 1.
            bonus = torch.log(torch.clamp(1.0 - probs, min=self.cl_eps))

        weight = self.weight.to(lprobs.device)
        if x.dim() > 1:
            # Align the [C] weights with the class dim: (1, C, 1, ..., 1).
            weight = weight.view(1, -1, *([1] * (x.dim() - 2)))
        # nll_loss negates the gathered entry, so c_loss = mean(w_y * bonus_y).
        c_loss = F.nll_loss(
            -bonus * weight,
            target,
            reduction='mean',
            ignore_index=self.ignore_idx,
        )
        return mle_loss + c_loss, org_loss

    def set_epoch(self, epoch):
        """Record the current epoch, given as a 0-based index (stored 1-based)."""
        self.epoch = epoch + 1  # epoch count from 1

    def get_epoch(self):
        """Return the current 1-based epoch."""
        return self.epoch



