import torch

class ChannelNormalizedMSELoss(torch.nn.Module):
    """
    MSE loss in which each channel's error is normalized by that channel's target variance.

    For every channel, the squared error between ``pred`` and ``target`` is
    reduced over all non-channel dimensions (batch and any spatial dims). It is
    then divided by the population variance of ``target`` over those same
    dimensions. The per-channel results are averaged to give a scalar loss, so
    channels with very different value ranges contribute on a comparable scale.

    Inputs use channel-first layout ``[batch, channel, *spatial]``, for example
    ``[batch, channel, x, y, z]``. Any number of trailing spatial dimensions,
    including none, is accepted. Computation is fully vectorized.

    Args:
        reduction: How the squared error is reduced before normalizing.
            ``'mean'`` averages it over batch and spatial positions; ``'sum'``
            sums it. In both cases the final per-channel values are averaged
            across channels.
        eps: Added to each channel's target variance so that constant channels
            do not cause a division by zero.

    Raises:
        ValueError: If ``reduction`` is not ``'mean'`` or ``'sum'``.
    """

    _REDUCTIONS = ('mean', 'sum')

    def __init__(self, reduction='mean', *, eps=1e-7):
        super().__init__()
        # Reject unknown modes up front instead of silently falling back to 'sum'.
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.reduction = reduction
        self.eps = eps

    def extra_repr(self):
        """Show the configuration when the module is printed."""
        return f"reduction={self.reduction!r}, eps={self.eps}"

    @staticmethod
    def _non_channel_dims(tensor):
        """Return every dim index of ``tensor`` except the channel dim (1)."""
        if tensor.dim() < 2:
            raise ValueError(
                "expected input of shape [batch, channel, *spatial] with at least "
                f"2 dimensions, got shape {tuple(tensor.shape)}"
            )
        return (0, *range(2, tensor.dim()))

    def forward(self, pred, target):
        """Compute the channel-normalized MSE between ``pred`` and ``target``.

        Args:
            pred: Predictions of shape ``[batch, channel, *spatial]``.
            target: Targets with the same shape as ``pred``.

        Returns:
            A 0-dim tensor holding the loss.

        Raises:
            ValueError: If an input has fewer than 2 dimensions, or if
                ``self.reduction`` was reassigned to an unsupported value.
        """
        # Squared error, same shape as the inputs.
        squared_error = (pred - target) ** 2
        # For 5-D input this is (0, 2, 3, 4), the dims used before generalizing.
        error_dims = self._non_channel_dims(squared_error)

        # Per-channel error, reduced over batch and spatial dims. Shape: [channel].
        if self.reduction == 'mean':
            channel_error = squared_error.mean(dim=error_dims)
        elif self.reduction == 'sum':
            channel_error = squared_error.sum(dim=error_dims)
        else:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {self.reduction!r}"
            )

        # Per-channel population variance of the target (divides by N, not N-1).
        target_dims = self._non_channel_dims(target)
        # keepdim=True keeps shape [1, channel, 1, ...] so it broadcasts against target.
        target_mean = target.mean(dim=target_dims, keepdim=True)
        # Shape: [channel].
        target_var = ((target - target_mean) ** 2).mean(dim=target_dims)

        # eps guards against constant (zero-variance) target channels.
        normalized_channel_error = channel_error / (target_var + self.eps)

        # Average across channels so every channel carries equal weight.
        return normalized_channel_error.mean()

# Default module-level loss instance: channel-normalized MSE with 'mean' reduction.
loss_fn = ChannelNormalizedMSELoss(reduction='mean')