import torch

# Shared MSE criteria used by the loss helpers in this module.
# `mse_fn` reduces to a single scalar mean over all elements (the MSELoss
# default, spelled out here); `per_element_mse_fn` keeps the elementwise
# squared error so callers can reduce it themselves.
mse_fn = torch.nn.MSELoss(reduction='mean')
per_element_mse_fn = torch.nn.MSELoss(reduction='none')

def per_element_loss_fn(pred, target, mode, element_mask = None):
    """Squared error between `pred` and `target`, averaged over the last axis.

    Args:
        pred (torch.Tensor): Predictions. Presumably (batch_size, N, Dy) like
            the other helpers in this module -- TODO confirm with callers.
        target (torch.Tensor): Targets, same (or broadcastable) shape as `pred`.
        mode: Not read here; kept so all loss helpers share one signature.
        element_mask (optional): Index into the result; the selected entries
            are overwritten with 0 (in place on the freshly computed tensor).
    Returns:
        torch.Tensor: The elementwise squared error with its last axis
        averaged away.
    """
    squared_error = per_element_mse_fn(pred, target)
    loss = squared_error.mean(dim=-1)
    if element_mask is None:
        return loss
    # Zero out the masked positions so they contribute nothing downstream.
    loss[element_mask] = 0
    return loss

# Lookup tables keyed by mode name ('sine', 'tanh', 'gaussian', 'relu' --
# presumably the network's activation type; verify against callers).
# Every mode is evaluated with PSNR and uses the mean-reduced MSE criterion.
metric_name = dict.fromkeys(('sine', 'tanh', 'gaussian', 'relu'), 'psnr')

# Same keys, in the same order, as `metric_name`.
loss_fn = dict.fromkeys(metric_name, mse_fn)


def batch_loss_fn(pred, target, mode, batch_mask = None, norm = 1.0):
    """Compute the loss between two batches of predictions and targets
    while preserving the batch dimension (one loss value per batch element).

    The elementwise squared error is divided by `norm` and then averaged over
    every non-batch axis.

    Args:
        pred (torch.Tensor): Shape (batch_size, N, Dy)
        target (torch.Tensor):
            if `mode` is semseg --> Shape (batch_size, N, Dy)
            else --> Shape (batch_size, N, 1)
        mode: Not read by this function; kept for a uniform signature with
            the other loss/metric helpers.
        batch_mask (optional): Index applied to the per-batch loss
            (`batch_loss[batch_mask]`), e.g. a boolean tensor of shape
            (batch_size,). `None` keeps every element.
        norm (float): Divisor applied to the squared error.
    Returns:
        Loss Tensor of shape (batch_size, ), or the subset selected by
        `batch_mask` when one is given.
    """
    per_element_loss = per_element_mse_fn(pred, target).div(norm)
    # Flatten instead of `.view(pred.shape[0], -1)`:
    #  * elementwise ops on a non-contiguous `pred` (e.g. a transposed tensor)
    #    may return a non-contiguous result, on which `.view` raises;
    #  * `.view(0, -1)` raises for an empty batch (the -1 is ambiguous);
    #  * the batch size is taken from the loss itself, which reflects any
    #    broadcasting between `pred` and `target`.
    if per_element_loss.dim() > 1:
        per_element_loss = per_element_loss.flatten(start_dim=1)
    else:
        # 1-D input: each entry is its own batch element (matches the old
        # `.view(batch_size, -1)` behaviour).
        per_element_loss = per_element_loss.unsqueeze(1)
    batch_loss = per_element_loss.mean(dim = 1)

    if batch_mask is not None:
        batch_loss = batch_loss[batch_mask]
    return batch_loss


@torch.no_grad()
def batch_metric_fn(pred, target, mode, batch_mask = None, norm = 1.0, eps = 1e-9, mask_value='except', peak = 1.0):
    """Compute the PSNR between two batches of predictions and targets
    while preserving the batch dimension (one metric value per batch element).

    For each output channel the PSNR is
    ``10 * log10(peak**2 / (mse + eps))``, where ``mse`` is the squared error
    divided by `norm` and averaged over axis 1. The per-channel values are
    then averaged. Runs under ``torch.no_grad()``.

    Args:
        pred (torch.Tensor): Shape (batch_size, L, Dy)
        target (torch.Tensor): Shape (batch_size, L, Dy)
        mode: Not read by this function; kept for a uniform signature with
            the other loss/metric helpers.
        batch_mask (optional): Selects batch elements. With
            ``mask_value='zero'`` it must support ``~`` (e.g. a boolean tensor
            of shape (batch_size,)).
        norm (float): Divisor applied to the squared error.
        eps (float): Added to the MSE so a perfect match does not divide by 0.
        mask_value (str): How `batch_mask` is applied when given:
            ``'except'`` -- drop unselected elements before computing PSNR;
            ``'zero'`` -- compute PSNR for all elements, then set the
            unselected ones to 0. Any other value leaves the result unmasked.
        peak (float): Maximum possible signal value. The default 1.0 matches
            signals lying in [0, 1].
    Returns:
        Metric Tensor of shape (batch_size,). With ``mask_value='except'`` and
        a mask, the shape is (number of selected elements,) instead.
    """
    if batch_mask is not None and mask_value == 'except':
        pred = pred[batch_mask]
        target = target[batch_mask]
    # Per-channel MSE over axis 1 -> (batch_size, 1, Dy)
    noise = (pred - target).pow(2).div(norm).mean(1, keepdim=True)
    # Standard PSNR uses the squared peak; for peak == 1.0 this equals the
    # former ``1.0 / noise`` form exactly.
    batchwise_channelwise_psnr = 10 * torch.log10(peak ** 2 / (noise + eps))  # (batch_size, 1, Dy)
    batchwise_psnr = batchwise_channelwise_psnr.mean([1, 2])

    if batch_mask is not None and mask_value == 'zero':
        batchwise_psnr.masked_fill_(~batch_mask, 0)

    return batchwise_psnr


def mse2psnr(mse):
    """Computes PSNR from MSE, assuming the MSE was calculated between signals
    lying in [0, 1] (peak value 1, so ``PSNR = -10 * log10(mse)``).

    Args:
        mse (torch.Tensor or float): Mean squared error value(s). Python
            numbers are converted with ``torch.as_tensor``.
    Returns:
        torch.Tensor: PSNR in dB with the same shape as `mse` (0-dim for a
        Python scalar).
    """
    # torch.log10 only accepts tensors and raises TypeError on a plain float;
    # as_tensor returns an existing tensor unchanged (autograd preserved).
    return -10.0 * torch.log10(torch.as_tensor(mse))
