# /usr/bin/env python
# -*- coding: utf-8 -*-

import torch
import torch.nn.functional as F


class weighted_loss(object):
    """Cross-entropy loss with optional per-sample weights.

    The per-sample loss is the negative log-softmax probability of the target
    class. Without an ``index`` the per-sample losses are averaged. With an
    ``index`` they are combined by a dot product with ``estimate_weights[index]``.
    That is a weighted *sum*, not a mean.
    """

    def __init__(self, device, nn_loss=False):
        """
        Args:
            device: Device on which the weights and per-sample losses live.
            nn_loss: If True, ``set_weights`` clamps negative weights to zero.
                The unclamped copy is still returned by ``get_weights``.
        """
        self.estimate_weights = None
        # Unclamped weights as passed to set_weights(). None until then, so
        # get_weights() does not raise AttributeError.
        self.ori_estimate_priors = None
        self.device = device
        self.nn_loss = nn_loss
        if nn_loss:
            print('Non negative loss')

        # NOTE(review): not read anywhere in this class; kept for external users.
        self.zero_one_loss = False
        # Only targets in [0, class_num) contribute a loss. With the default of
        # 0 the loss is identically zero, so call set_class_num() before use.
        self.class_num = 0

    def set_weights(self, estimate_priors):
        """Store per-sample weights on ``self.device``.

        If ``nn_loss`` is enabled, the weights used by ``__call__`` are
        clamped to be non-negative. The original values are kept for
        ``get_weights``.
        """
        self.estimate_weights = estimate_priors.to(self.device)
        self.ori_estimate_priors = self.estimate_weights
        if self.nn_loss:
            # torch.clamp returns a new tensor, so ori_estimate_priors keeps
            # the unclamped values.
            self.estimate_weights = torch.clamp(self.estimate_weights, min=0)

    def set_class_num(self, cls_num):
        """Set the number of classes; targets >= cls_num get zero loss."""
        self.class_num = cls_num

    def get_weights(self):
        """Return the weights as passed to ``set_weights`` (before any clamping).

        Returns None if ``set_weights`` has not been called.
        """
        return self.ori_estimate_priors

    def __call__(self, outputs, target, index):
        """Compute the (optionally weighted) cross-entropy loss.

        Args:
            outputs: Logits; log-softmax is taken over dim 1. Presumably shaped
                (batch, num_classes) -- confirm against callers.
            target: Class index for each row of ``outputs``.
            index: None for a plain mean. Otherwise, indices into
                ``estimate_weights`` that select one weight per sample.

        Returns:
            A 0-d tensor: the mean per-sample loss when ``index`` is None,
            otherwise the dot product of the selected weights with the
            per-sample losses.
        """
        log_soft = F.log_softmax(outputs.float(), 1)
        loss_vector = torch.zeros(len(log_soft), device=self.device)

        # Vectorised gather of -log p(target). Only integer-valued targets in
        # [0, class_num) receive a loss; all other entries stay 0. This matches
        # the semantics of masking with `target == i` for each class i.
        cols = target.long()
        valid = (cols == target) & (cols >= 0) & (cols < self.class_num)
        loss_vector[valid] = -log_soft[valid, cols[valid]]

        # `is None` rather than `== None`: an array-like index would otherwise
        # be compared element-wise, and the `if` would become ambiguous.
        if index is None:
            return loss_vector.mean()

        # Cast so .dot() does not fail when the weights were supplied in a
        # different float dtype (e.g. float64) than the float32 losses.
        priors_ = self.estimate_weights[index].to(loss_vector.dtype)
        return priors_.dot(loss_vector)
