# This code is referenced from 
# https://github.com/facebookresearch/astmt/
# 
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# 
# License: Attribution-NonCommercial 4.0 International

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.module import Module
import numpy as np


class BalancedCrossEntropyLoss(Module):
    """
    Balanced (class-weighted) binary cross entropy on raw scores, with
    optional ignore regions.

    The per-element term is the sigmoid cross entropy of ``output`` against
    the binarised ``label`` (``label >= 0.5`` is positive). Positive and
    negative elements are summed separately and combined as
    ``w * loss_pos + (1 - w) * loss_neg``.

    Args:
        size_average: if true, divide the loss by the number of elements in
            ``label``. Takes precedence over ``batch_average``.
        batch_average: if true (and ``size_average`` is false), divide the
            loss by ``label.size()[0]``.
        pos_weight: fixed weight ``w`` for the positive term. When ``None``
            (default) HED-style weighting is used: ``w`` is the fraction of
            negative elements in the current input.

    Example:
        criterion = BalancedCrossEntropyLoss(size_average=True)
    """

    def __init__(self, size_average=True, batch_average=True, pos_weight=None):
        super(BalancedCrossEntropyLoss, self).__init__()
        self.size_average = size_average
        self.batch_average = batch_average
        self.pos_weight = pos_weight

    def forward(self, output, label, void_pixels=None):
        """
        Compute the balanced cross entropy loss.

        Args:
            output: raw (pre-sigmoid) scores; must have the same size as
                ``label``.
            label: targets, thresholded at 0.5 into {0, 1}.
            void_pixels: optional mask; elements ``> 0.5`` are excluded from
                the loss. Only applied with HED-style weighting
                (``pos_weight is None``); ignored when a fixed ``pos_weight``
                is set.

        Returns:
            The weighted loss, normalised according to ``size_average`` /
            ``batch_average``.
        """
        assert (output.size() == label.size())
        labels = torch.ge(label, 0.5).float()

        # Decide the weighting mode once, with an identity check. A truthiness
        # test would misclassify pos_weight == 0 as "not set" and is ambiguous
        # for multi-element tensors.
        hed_weighting = self.pos_weight is None

        # Weighting of the loss, default is HED-style
        if hed_weighting:
            num_labels_pos = torch.sum(labels)
            num_labels_neg = torch.sum(1.0 - labels)
            num_total = num_labels_pos + num_labels_neg
            w = num_labels_neg / num_total
        else:
            w = self.pos_weight

        # Numerically stable log-likelihood under a sigmoid: the exponent
        # below equals -|output|, so exp() cannot overflow.
        output_gt_zero = torch.ge(output, 0).float()
        loss_val = torch.mul(output, (labels - output_gt_zero)) - torch.log(
            1 + torch.exp(output - 2 * torch.mul(output, output_gt_zero)))

        loss_pos_pix = -torch.mul(labels, loss_val)
        loss_neg_pix = -torch.mul(1.0 - labels, loss_val)

        # NOTE(review): void pixels are only masked in HED mode; with a fixed
        # pos_weight they still contribute to the loss. Kept as-is for
        # backward compatibility -- confirm whether that is intended.
        if void_pixels is not None and hed_weighting:
            w_void = torch.le(void_pixels, 0.5).float()
            loss_pos_pix = torch.mul(w_void, loss_pos_pix)
            loss_neg_pix = torch.mul(w_void, loss_neg_pix)
            # Re-balance using only the non-void element count.
            num_total = num_total - torch.ge(void_pixels, 0.5).float().sum()
            w = num_labels_neg / num_total

        loss_pos = torch.sum(loss_pos_pix)
        loss_neg = torch.sum(loss_neg_pix)

        final_loss = w * loss_pos + (1 - w) * loss_neg

        if self.size_average:
            final_loss = final_loss / float(label.numel())
        elif self.batch_average:
            final_loss = final_loss / label.size()[0]

        return final_loss




# class Normalize(nn.Module):
#     def __init__(self):
#         super(Normalize, self).__init__()

#     def forward(self, bottom):
#         qn = torch.norm(bottom, p=2, dim=1).unsqueeze(dim=1) + 1e-12
#         top = bottom.div(qn)

#         return top


# class NormalsLoss(Module):
#     """
#     L1 loss with ignore labels
#     normalize: normalization for surface normals

#     NormalsLoss(normalize=True, size_average=True, norm=1)
#     """
#     def __init__(self, size_average=True, normalize=False, norm=1):
#         super(NormalsLoss, self).__init__()

#         self.size_average = size_average

#         if normalize:
#             self.normalize = Normalize()
#         else:
#             self.normalize = None

#         if norm == 1:
#             # print('Using L1 loss for surface normals')
#             self.loss_func = F.l1_loss
#         elif norm == 2:
#             # print('Using L2 loss for surface normals')
#             self.loss_func = F.mse_loss
#         else:
#             raise NotImplementedError

#     def forward(self, out, label, ignore_label=255):
#         assert not label.requires_grad
#         mask = (label != ignore_label)
#         n_valid = torch.sum(mask).item()

#         if self.normalize is not None:
#             out_norm = self.normalize(out)
#             loss = self.loss_func(torch.masked_select(out_norm, mask), torch.masked_select(label, mask), reduction='sum')
#         else:
#             loss = self.loss_func(torch.masked_select(out, mask), torch.masked_select(label, mask), reduction='sum')

#         if self.size_average:
#             if ignore_label:
#                 ret_loss = torch.div(loss, max(n_valid, 1e-6))
#                 return ret_loss
#             else:
#                 ret_loss = torch.div(loss, float(np.prod(label.size())))
#                 return ret_loss

#         return loss
