from torch import nn

class MSELossWrapper(nn.Module):
    """
    A wrapper class for Mean Squared Error (MSE) Loss calculation.

    This class provides a convenient interface for computing Mean Squared Error loss
    in regression or multi-output prediction tasks. It flattens the input tensors
    to handle multi-dimensional prediction scenarios, such as time series or 
    multi-step forecasting.

    Attributes:
        loss_fn (nn.MSELoss): The underlying Mean Squared Error loss function.
            Always uses 'mean' reduction to compute the average squared differences.
    """
    def __init__(self, reduction='mean'):
        """
        Initialize the MSELossWrapper with a Mean Squared Error loss function.

        Args:
            reduction (str, optional): Reduction method for loss calculation.
                Defaults to 'mean'.
        """
        super(MSELossWrapper, self).__init__()
        self.loss_fn = nn.MSELoss(reduction="mean")
        self.num_outputs = 12  
    
    def forward(self, pred, batch):
          """
        Compute Mean Squared Error (MSE) loss between predictions and ground truth.

        This method flattens both predictions and ground truth to handle 
        multi-dimensional input scenarios. Assumes a specific output shape 
        of 12 elements per sample.

        Args:
            pred (torch.Tensor): Model predictions.
                Expected to be a tensor with shape compatible with flattening 
                to (batch_size * time_steps, 12).
            batch (dict): Batch dictionary containing ground truth labels.
                Must have a 'label' key with ground truth values.

        Returns:
            torch.Tensor: Computed Mean Squared Error loss.
                A scalar value representing the average squared differences 
                between predictions and ground truth.
        """
        y = batch['label']
        y_flat = y.reshape(-1, self.num_outputs)  # Shape: (batch_size * time_steps, num_outputs=12)
        pred_flat = pred.reshape(-1, self.num_outputs) # Shape: (batch_size * time_steps, num_outputs=12)
        loss = self.loss_fn(pred_flat, y_flat)
        return loss
