import math

import torch
from torch import Tensor, nn


class GaussianMixtureLoss(nn.Module):
    """Negative log-likelihood of ground-truth points under a diagonal Gaussian mixture.

    Each ground-truth point is scored by the log-density of a mixture whose
    components have weights ``p``, means ``x`` and per-dimension standard
    deviations ``u``. Point log-likelihoods are averaged over the valid points
    of each sample, then over samples, and the negated value is returned.

    Args:
        eps: Small constant used when L1-normalising the mixture weights and
            inside ``log`` so that zero weights give a finite log-probability.
    """

    def __init__(self, eps: float):
        super().__init__()
        self.eps = eps
        # Exact log(2*pi) (was hard-coded as 1.8379).
        self.log_2pi = math.log(2.0 * math.pi)

    def forward(
        self, p: Tensor, x: Tensor, u: Tensor, x_gt: Tensor, mask_gt: Tensor
    ) -> Tensor:
        """Return the mean negative log-likelihood of ``x_gt`` under the mixture.

        Shapes follow from the broadcasting below (B = batch, K = components,
        N = ground-truth points, D = feature dimension):

        Args:
            p: ``(B, K)`` non-negative mixture weights; L1-normalised over K here.
            x: ``(B, K, D)`` component means.
            u: ``(B, K, D)`` component standard deviations; must be positive.
            x_gt: ``(B, N, D)`` ground-truth points.
            mask_gt: ``(B, N)`` validity mask; bool, integer or 0/1 float masks
                are accepted (non-zero means valid).

        Returns:
            Scalar tensor. Samples without any valid point are skipped; if no
            sample has a valid point the result is NaN.
        """
        # (B, N, K, D): residual of every GT point against every component mean.
        diff = x[:, None] - x_gt[:, :, None]
        sigma = u[:, None]
        # Per-dimension Gaussian log-density with std sigma:
        #   log N(d; 0, sigma^2) = -0.5 * ((d / sigma)^2 + log(sigma^2) + log(2*pi))
        # log(sigma^2) == 2*log(sigma); using log(sigma) alone would inflate the
        # optimal sigma by sqrt(2). Summing over D gives (B, N, K).
        log_density = -0.5 * torch.sum(
            torch.square(diff / sigma) + 2.0 * torch.log(sigma) + self.log_2pi,
            dim=-1,
        )
        weights = torch.nn.functional.normalize(p, dim=-1, p=1.0, eps=self.eps)
        # Mixture log-likelihood per GT point: (B, N).
        log_lik = torch.logsumexp(
            log_density + torch.log(weights + self.eps)[:, None], dim=-1
        )
        # NaN-out invalid points so nanmean ignores them; .bool() makes the
        # negation logical even for integer/float masks.
        log_lik = log_lik.masked_fill(~mask_gt.bool(), torch.nan)
        per_sample = log_lik.nanmean(dim=-1)
        return -per_sample.nanmean()
