import torch
import torch.nn as nn

class StdLoss(nn.Module):
    """Per-sample MSE plus a penalty on the spread of the predictions.

    For each slice along ``dim`` the per-sample loss is::

        mean((x - target) ** 2) + std_weight * std(x)

    and the per-sample values are then reduced according to ``reduction``.

    Args:
        reduction: ``'mean'`` (default) averages the per-sample losses,
            ``'sum'`` adds them, ``'none'`` returns them unreduced.
        dim: Dimension along which the MSE and standard deviation are
            computed. Defaults to 1.
        std_weight: Multiplier applied to the standard-deviation term.
            Defaults to 1.0, i.e. MSE and std are simply added.

    Raises:
        ValueError: if ``reduction`` is not ``'mean'``, ``'sum'`` or ``'none'``.
    """

    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self, reduction='mean', dim=1, std_weight=1.0):
        super().__init__()
        # Reject unknown values up front: otherwise a typo such as 'meen'
        # would silently fall through to returning the unreduced tensor,
        # which only fails much later (e.g. at backward() on a non-scalar).
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.reduction = reduction  # how per-sample losses are combined
        self.dim = dim  # dimension reduced by both the MSE and the std
        self.std_weight = std_weight  # weight of the spread penalty

    def forward(self, x, target):
        """Compute the combined MSE + standard-deviation loss.

        Args:
            x: Floating-point predictions that have dimension ``self.dim``
                (so rank >= 2 for the default ``dim=1``).
            target: Tensor broadcastable against ``x``. Broadcasting is not
                checked here, so a target that broadcasts to an unintended
                shape will not raise -- callers should ensure shapes match.

        Returns:
            A scalar tensor for ``'mean'`` / ``'sum'``; for ``'none'`` the
            unreduced loss tensor with ``self.dim`` reduced away.
        """
        # Spread of the predictions. torch.std applies Bessel's correction by
        # default, so a ``dim`` of size 1 yields NaN for this term.
        std_dev = torch.std(x, dim=self.dim)

        # Mean squared error between predictions and target along ``dim``.
        mse_loss = torch.mean((x - target) ** 2, dim=self.dim)

        # Penalize both poor predictions and a large spread of predictions.
        loss = mse_loss + self.std_weight * std_dev

        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss  # 'none' (validated in __init__)
