
import math

import torch
import torch.nn.functional as F

# Public names exported by `from <module> import *`; the `_apply_mask` helper is
# private and intentionally not listed.
__all__ = ["cross_entropy", "accuracy", "mse", "kld", "gaussian_nll"]


def _apply_mask(tensor, mask):
    """
    Applies a mask to the input tensor. If no mask is provided, a mask of ones (True) with the same shape as the input tensor is used.

    Args:
        tensor (torch.Tensor): The tensor to apply the mask to.
        mask (torch.Tensor, optional): The mask to apply. If None, a mask of ones (True) is used.

    Returns:
        torch.Tensor: The masked tensor.
    """
    if mask is None:
        mask = torch.ones_like(tensor).bool()
    return tensor[mask]


def cross_entropy(input, target, valid_mask=None, reduction='none', label_smoothing=0.):
    """
    Computes the masked cross entropy loss over a collection of per-feature predictions.

    ``input`` and ``target`` are iterated in lockstep, and each pair is scored with
    ``F.cross_entropy``. Pairs whose target has a last dimension of size 1 are skipped.

    Args:
        input (iterable of torch.Tensor): Predicted logits, one tensor per feature.
        target (iterable of torch.Tensor): Ground truth labels aligned with ``input``. Only the
            size of the last dimension is inspected, which suggests one-hot / class-probability
            targets of shape (batch, num_classes) -- TODO confirm against callers.
        valid_mask (torch.Tensor, optional): Boolean mask to apply. It is split along dim 1 into
            one column per feature; each column, with its trailing size-1 dim squeezed, masks that
            feature's logits and targets (presumably a (batch, num_features) mask -- verify).
            Default is None.
        reduction (str, optional): 'mean' averages all losses into a tensor of shape (1,); any
            other value returns the unreduced losses. Default is 'none'.
        label_smoothing (float, optional): Forwarded to ``F.cross_entropy`` for every scored
            feature, with or without a mask. Default is 0.

    Returns:
        torch.Tensor: Without a mask, the per-feature losses stacked along the last dimension.
        With a mask, the selected losses of all features concatenated into one 1-D tensor.
        With reduction == 'mean', their mean as shape (1,).

    Raises:
        RuntimeError: Raised by ``torch.stack`` / ``torch.cat`` when no feature has a target whose
            last dimension exceeds 1, because the list of losses is then empty.
    """
    if valid_mask is None:
        # Forward label_smoothing exactly as the masked branch does, so the caller's
        # setting is honoured regardless of whether a mask is supplied.
        losses = [F.cross_entropy(split_input, split_target,
                                  label_smoothing=label_smoothing, reduction='none').view(-1)
                  for split_input, split_target in zip(input, target)
                  if split_target.shape[-1] > 1]
        loss = torch.stack(losses, dim=-1)
    else:
        losses = [F.cross_entropy(_apply_mask(split_input, split_mask.squeeze(-1)),
                                  _apply_mask(split_target, split_mask.squeeze(-1)),
                                  label_smoothing=label_smoothing, reduction='none').view(-1)
                  for split_input, split_target, split_mask in zip(input, target, valid_mask.split(1, dim=1))
                  if split_target.shape[-1] > 1]
        # Masked features may keep different numbers of rows, so concatenate rather than stack.
        loss = torch.cat(losses)

    if reduction == 'mean':
        loss = loss.mean().view(-1)

    return loss

def accuracy(input, target, valid_mask=None, reduction='none'):
    """
    Computes the masked top-1 accuracy across a collection of per-feature predictions.

    Each element of ``input`` is compared with the matching element of ``target`` by taking
    ``argmax`` over dimension 1 of both, giving a 0./1. hit indicator per row. The indicators
    of all features are stacked along dimension 1 and then passed through ``_apply_mask``.

    Args:
        input (iterable of torch.Tensor): Predicted logits, one tensor per feature (assumed
            (batch, num_classes) since argmax is taken over dim 1 -- TODO confirm).
        target (iterable of torch.Tensor): Ground truth one-hot encoded labels, aligned with ``input``.
        valid_mask (torch.Tensor, optional): Boolean mask applied to the stacked hit tensor.
            Default is None, which keeps every entry (flattened).
        reduction (str, optional): 'mean' averages to a tensor of shape (1,); any other value
            returns the unreduced (masked) hits. Default is 'none'.

    Returns:
        torch.Tensor: Per-entry accuracies (0. or 1.), or their mean when reduction == 'mean'.
    """
    hits = []
    for logits, labels in zip(input, target):
        predicted_class = logits.argmax(dim=1)
        true_class = labels.argmax(dim=1)
        hits.append(predicted_class.eq(true_class).float())

    per_entry = _apply_mask(torch.stack(hits, dim=1), valid_mask)
    if reduction != 'mean':
        return per_entry
    return per_entry.mean().view(-1)


def mse(input, target, valid_mask=None, reduction='none'):
    """
    Computes the masked mean squared error (MSE) loss.

    ``input`` and ``target`` are both passed through ``_apply_mask`` with the same mask
    before the element-wise squared error is taken, so only the selected entries
    contribute. When ``valid_mask`` is None every entry is kept and the result is
    flattened to 1-D.

    Args:
        input (torch.Tensor): Predicted values.
        target (torch.Tensor): Ground truth values, same shape as ``input``.
        valid_mask (torch.Tensor, optional): Boolean mask to apply. Default is None.
        reduction (str, optional): 'mean' averages to a tensor of shape (1,); any other
            value returns the unreduced squared errors. Default is 'none'.

    Returns:
        torch.Tensor: The squared errors of the selected entries, or their mean when
        reduction == 'mean'.
    """
    masked_input = _apply_mask(input, valid_mask)
    masked_target = _apply_mask(target, valid_mask)
    loss = F.mse_loss(masked_input, masked_target, reduction='none')

    if reduction == 'mean':
        loss = loss.mean().view(-1)

    return loss


def kld(mu, logvar, valid_mask=None, reduction='none'):
    """
    Computes the masked Kullback-Leibler divergence (KLD) loss.

    Per element, this is the closed-form KL divergence between N(mu, exp(logvar)) and
    N(0, 1). It is averaged (not summed) over the last dimension and then passed
    through ``_apply_mask``.

    Args:
        mu (torch.Tensor): Mean of the latent variable distribution.
        logvar (torch.Tensor): Log variance of the latent variable distribution, same shape as ``mu``.
        valid_mask (torch.Tensor, optional): Boolean mask applied to the loss after the
            last dimension has been averaged out. Default is None.
        reduction (str, optional): 'mean' averages to a tensor of shape (1,); any other
            value returns the unreduced (masked) losses. Default is 'none'.

    Returns:
        torch.Tensor: The computed KLD loss.
    """
    variance = logvar.exp()
    elementwise = -0.5 * (1 + logvar - mu.pow(2) - variance)
    per_item = _apply_mask(elementwise.mean(dim=-1), valid_mask)
    return per_item.mean().view(-1) if reduction == 'mean' else per_item


def gaussian_nll(mu, logvar, target, vaid_mask=None, reduction='none', valid_mask=None):
    """
    Computes the Gaussian negative log-likelihood of ``target`` under N(mu, exp(logvar)).

    The loss is evaluated directly in terms of ``logvar``::

        0.5 * (log(2*pi) + logvar) + 0.5 * (target - mu)**2 * exp(-logvar)

    This is mathematically identical to
    ``0.5 * log(2*pi*sigma2) + (target - mu)**2 / (2*sigma2)`` with ``sigma2 = exp(logvar)``.
    It avoids the ``log(exp(logvar))`` round trip, which yields ``inf`` once ``exp(logvar)``
    overflows and ``log(0) = -inf`` once it underflows. ``logvar`` is NOT clamped: for
    extremely negative ``logvar`` the loss can still be ``inf``, or ``nan`` where
    ``target == mu`` exactly. Clamp ``logvar`` beforehand if that matters.

    Args:
        mu (torch.Tensor): Predicted mean.
        logvar (torch.Tensor): Predicted log variance.
        target (torch.Tensor): Ground truth.
        vaid_mask (torch.Tensor, optional): Boolean mask (misspelled name kept for backward
            compatibility). Default is None.
        reduction (str, optional): 'mean' or 'sum' reduce to a tensor of shape (1,); any other
            value (e.g. 'none') returns the element-wise loss. Default is 'none'.
        valid_mask (torch.Tensor, optional): Correctly spelled alias of ``vaid_mask``, matching
            the other loss functions. Default is None.

    Returns:
        torch.Tensor: The loss. Without a mask it has the broadcast shape of the inputs; with
        a mask it holds only the selected entries.

    Raises:
        ValueError: If both ``vaid_mask`` and ``valid_mask`` are given.
    """
    if vaid_mask is not None and valid_mask is not None:
        raise ValueError("Pass only one of `valid_mask` and its misspelled alias `vaid_mask`.")
    mask = valid_mask if valid_mask is not None else vaid_mask

    if mask is not None:
        mu = _apply_mask(mu, mask)
        logvar = _apply_mask(logvar, mask)
        target = _apply_mask(target, mask)

    loss = 0.5 * (math.log(2 * math.pi) + logvar) + 0.5 * (target - mu) ** 2 * torch.exp(-logvar)

    if reduction == 'mean':
        loss = loss.mean().view(-1)
    elif reduction == 'sum':
        loss = loss.sum().view(-1)

    return loss
