import torch
from torch import Tensor, nn


class BernoulliLoss(nn.Module):
    """Gaussian negative log-likelihood of a count of independent Bernoulli events.

    The per-event probabilities ``p`` along the last dimension are treated as
    independent Bernoulli trials. Their sum has mean ``mu = sum(p)`` and
    variance ``var = sum(p * (1 - p))``. The loss scores the observed count
    ``n_gt`` under a normal distribution with those moments:

        0.5 * ((n_gt - mu) ** 2 / var + log(var))

    This is then averaged over all remaining elements. The constant
    ``0.5 * log(2 * pi)`` term of the Gaussian NLL is intentionally omitted,
    because it does not affect gradients.

    Args:
        eps: Probabilities are clipped to ``[eps, 1 - eps]`` before use. This
            keeps ``var`` strictly positive (for a non-empty last dimension) and
            avoids division by zero. Must satisfy ``0 <= eps < 0.5``.

    Raises:
        ValueError: If ``eps`` is outside ``[0, 0.5)``. Outside that range the
            clip bounds cross, or ``p`` can leave ``[0, 1]``.
    """

    def __init__(self, eps: float):
        super().__init__()
        # With eps >= 0.5, clip(min=eps, max=1-eps) has min > max. torch.clamp
        # then silently maps every value to max. A negative eps lets p leave
        # [0, 1] and var go non-positive. Reject both up front.
        # (NaN also fails this check.)
        if not 0.0 <= eps < 0.5:
            raise ValueError(f"eps must satisfy 0 <= eps < 0.5, got {eps!r}")
        self.eps = eps

    def forward(self, p: Tensor, n_gt: Tensor) -> Tensor:
        """Return the mean Gaussian NLL (without the log(2*pi) constant).

        Args:
            p: Event probabilities, summed over the last dimension.
            n_gt: Observed counts. Must broadcast against ``p.sum(dim=-1)``.

        Returns:
            A scalar tensor with the loss averaged over all non-reduced elements.
        """
        p = p.clip(min=self.eps, max=1.0 - self.eps)
        mu = torch.sum(p, dim=-1)
        # Variance of a sum of independent Bernoulli(p_i) variables.
        var = torch.sum(p * (1 - p), dim=-1)
        loss = torch.square(n_gt - mu) / var + torch.log(var)
        return 0.5 * loss.mean()
