from mmcv.runner import BaseModule
import torch.nn as nn
import torch
from mmdet.models.builder import LOSSES
import math


@LOSSES.register_module()
class GaussianNLLLoss(nn.Module):
    """Gaussian negative log-likelihood with a log-variance regulariser.

    The caller provides a predicted mean ``mu`` and a predicted log standard
    deviation ``log_sigma``. The variance ``exp(2 * log_sigma)`` is floored
    at ``min_var`` and passed to :class:`torch.nn.GaussianNLLLoss` with
    ``full=True``, so the constant ``0.5 * log(2 * pi)`` is included. The
    term ``-lambda_var * log(var)`` is then added to the result. Because
    that term decreases as ``var`` grows, it rewards larger predicted
    variances.

    Args:
        loss_weight (float): Multiplier applied to the final loss.
        reduction (str): ``'mean'``, ``'sum'`` or ``'none'``. It is
            forwarded to :class:`torch.nn.GaussianNLLLoss` and also decides
            how the log-variance regulariser is reduced (see
            :meth:`forward`).
        eps (float): ``eps`` forwarded to :class:`torch.nn.GaussianNLLLoss`.
        lambda_var (float): Weight of the log-variance regulariser.
        min_var (float): Lower bound applied to the predicted variance.
    """

    def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-6,
                 lambda_var=0.1, min_var=0.1):
        super().__init__()
        self.loss_weight = loss_weight
        self.reduction = reduction
        self.eps = eps
        self.lambda_var = lambda_var
        self.min_var = min_var
        self.loss_fn = nn.GaussianNLLLoss(
            full=True, eps=self.eps, reduction=self.reduction)

    def forward(self, mu, log_sigma, target):
        """Compute the weighted, regularised Gaussian NLL.

        Args:
            mu (Tensor): Predicted means.
            log_sigma (Tensor): Predicted log standard deviations. Its
                variance must have a shape that
                :class:`torch.nn.GaussianNLLLoss` accepts as ``var`` for
                input ``mu``.
            target (Tensor): Ground-truth values.

        Returns:
            Tensor: A scalar for ``'mean'``/``'sum'`` reduction. For
            ``'none'``, the per-element NLL minus the scalar regulariser,
            all scaled by ``loss_weight``.
        """
        # exp(2 * log_sigma) == exp(log_sigma) ** 2, the variance.
        # NOTE: clamp passes no gradient for values below ``min_var``, so
        # while the variance sits under the floor ``log_sigma`` receives no
        # gradient from either term.
        var = torch.exp(2.0 * log_sigma).clamp(min=self.min_var)
        nll = self.loss_fn(mu, target, var)

        log_var = var.log()
        # Reduce the regulariser like the NLL term, so ``lambda_var`` keeps
        # the same relative strength under 'sum'. Under 'none', the mean is
        # subtracted from every element. When var has the same shape as mu,
        # a later mean/sum of the result then matches the 'mean'/'sum' total.
        if self.reduction == 'sum':
            reg = log_var.sum()
        else:
            reg = log_var.mean()

        return self.loss_weight * (nll - self.lambda_var * reg)