import torch
import torch.nn as nn

class NMSE(nn.Module):
    """Normalized Mean Squared Error loss.

    Computes ``mean(((pred - target) / (std(target) + eps)) ** 2)``, i.e. the
    mean squared error scaled by (approximately) the variance of ``target``.
    The standard deviation is a single scalar taken over *all* elements of
    ``target``; there is no per-dimension normalization.

    Args:
        eps: Added to the standard deviation so the division stays finite
            when ``target`` is constant. Defaults to ``1e-8``, the value
            previously hard-coded.
        unbiased: Passed to ``torch.std``; ``True`` applies Bessel's
            correction (divide by ``N - 1``). Defaults to ``True``, matching
            the original behavior. With ``unbiased=True`` a single-element
            ``target`` produces a NaN standard deviation and therefore a NaN
            loss; pass ``unbiased=False`` if that case can occur.
    """

    def __init__(self, *, eps: float = 1e-8, unbiased: bool = True):
        super().__init__()
        self.eps = eps
        self.unbiased = unbiased

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the scalar NMSE between ``pred`` and ``target``.

        Args:
            pred: Predicted values; must broadcast against ``target``.
            target: Reference values; their global standard deviation sets
                the normalization scale.

        Returns:
            A 0-dim tensor holding the mean squared normalized error.
        """
        error = pred - target
        # One scalar scale for the whole tensor; eps guards against a
        # zero-spread target.
        normalizer = torch.std(target, unbiased=self.unbiased) + self.eps
        return torch.mean((error / normalizer) ** 2)