import torch
import torch.nn as nn
import numpy as np 
from torch.nn import functional as F

class BalancedSoftmaxLoss(nn.Module):
    """Balanced Softmax cross-entropy for class-imbalanced classification.

    The log of the empirical class prior (``count_c / sum(counts)``) is added
    to the logits before a standard (mean-reduced) cross-entropy.

    Args:
        cls_num_list: Per-class sample counts; anything accepted by
            ``torch.as_tensor`` (list, numpy array, tensor).
    """

    def __init__(self, cls_num_list):
        super().__init__()
        # Convert up front so lists / numpy arrays work, not only tensors.
        counts = torch.as_tensor(cls_num_list, dtype=torch.float32)
        cls_prior = counts / counts.sum()
        # Non-persistent buffer: follows `module.to(device)` without adding a
        # state_dict key. Shape (1, C) so it broadcasts over the batch.
        # Note: a class with a zero count gets log(0) = -inf.
        self.register_buffer(
            "log_prior", torch.log(cls_prior).unsqueeze(0), persistent=False
        )

    def forward(self, logits, labels):
        """Return the mean cross-entropy of the prior-adjusted logits.

        Args:
            logits: Tensor of shape (batch size, num class).
            labels: Class indices accepted by ``F.cross_entropy``.
        """
        # Align device in case the loss module itself was never moved.
        adjusted_logits = logits + self.log_prior.to(logits.device)
        return F.cross_entropy(adjusted_logits, labels)
    
class ClassBalancedSoftmax(nn.Module):
    """Class-balanced softmax cross-entropy.

    https://arxiv.org/abs/1901.05555

    Each sample's negative log-likelihood is weighted by the inverse
    "effective number" of its class, ``(1 - beta**n_c) / (1 - beta)``, with
    the weights normalized so they sum to the number of classes.

    Args:
        cls_num_list: Per-class sample counts (list, numpy array or tensor).
        num_class: Unused; kept for backward compatibility.
        beta: Used to precompute the default weights. If ``None``, no default
            weights are stored and ``forward`` must be given ``beta``.
    """

    def __init__(self, cls_num_list, num_class=10, beta=0.9):
        super(ClassBalancedSoftmax, self).__init__()
        self.beta = beta

        counts = torch.as_tensor(cls_num_list, dtype=torch.float32)
        # Non-persistent buffers follow `module.to(device)` without adding
        # state_dict keys; no hard-coded `.cuda()`, so CPU-only runs work.
        self.register_buffer("counts_cls", counts, persistent=False)

        weights = self.calc_weight(beta) if beta is not None else None
        self.register_buffer("w", weights, persistent=False)

    def __count_per_class(self, labels, num_class):
        """Count occurrences of each label in ``[0, num_class)``.

        Assumes ``labels`` are non-negative integers below ``num_class``.
        Not called within this class.
        """
        unique_labels, count = np.unique(labels, return_counts=True)
        c_per_cls = np.zeros(num_class)
        c_per_cls[unique_labels] = count
        return c_per_cls

    def calc_weight(self, beta):
        """Compute normalized class weights from the effective number of samples.

        Args:
            beta: float, or tensor of shape (batch size, 1) / (1, 1).

        Returns:
            Tensor of shape (C,) for a scalar ``beta``, or
            (rows of beta, C) for a 2-D ``beta``. Each row sums to C.
        """
        counts = self.counts_cls
        if torch.is_tensor(beta):
            # Keep the pow() operands on one device.
            counts = counts.to(beta.device)

        # effective number
        ef_ns = (1 - torch.pow(beta, counts)) / (1 - beta)

        # weight = inverse effective number, normalized to sum to C per row
        w = 1 / ef_ns
        if w.dim() == 1:
            total = torch.sum(w)
        else:
            total = torch.sum(w, dim=1, keepdim=True)
        num_classes = counts.size(0)
        return w * num_classes / total

    def forward(self, input, label, beta=None):
        """Return the per-sample (unreduced) class-balanced NLL.

        Args:
            input: Logits of shape (batch size, num class).
            label: Class indices of shape (batch size,).
            beta: ``None`` to use the weights precomputed in ``__init__``;
                a float; or a tensor of shape (batch size, 1) or (1, 1).

        Returns:
            Loss tensor of shape (batch size, 1).

        Raises:
            ValueError: if ``beta`` is None and the module was built with
                ``beta=None``.
        """
        if beta is None:
            if self.w is None:
                raise ValueError(
                    "ClassBalancedSoftmax was built with beta=None; "
                    "pass beta to forward()."
                )
            w = self.w.to(label.device)[label].unsqueeze(1)  # (batch size, 1)
        else:
            w = self.calc_weight(beta).to(label.device)
            if w.dim() == 1:
                # Scalar beta: one weight vector shared by every sample.
                w = w[label].unsqueeze(1)  # (batch size, 1)
            else:
                # (batch size, num class) or (1, num class)
                if w.size(0) == 1 and label.size(0) != 1:
                    w = w.expand(label.size(0), w.size(1))
                w = torch.gather(w, -1, label.unsqueeze(1))  # (batch size, 1)

        logp = F.log_softmax(input, dim=-1)  # (batch size, num class)
        logp = torch.gather(logp, -1, label.unsqueeze(1))  # (batch size, 1)

        return -w * logp