import torch
import torch.nn as nn

def multiply_along_dim(tensor_a, tensor_b, dim):
    """Scale ``tensor_a`` element-wise by the entries of ``tensor_b`` along ``dim``.

    ``tensor_b`` is reshaped into a broadcastable form: every axis has size 1
    except axis ``dim``, which takes ``tensor_b.size(0)`` elements. The product
    with ``tensor_a`` then broadcasts over all other axes.

    Args:
        tensor_a: Tensor to be scaled.
        tensor_b: Tensor whose ``size(0)`` elements supply the per-index factors;
            its total element count must equal ``tensor_b.size(0)`` for the
            reshape to succeed.
        dim: Axis of ``tensor_a`` the factors are laid along (negative values
            count from the end, as with Python list indexing).

    Returns:
        The broadcast product ``tensor_a * reshaped(tensor_b)``.
    """
    # Start from an all-singleton shape of the same rank as ``tensor_a`` and
    # open up only the target axis.
    broadcast_shape = [1 for _ in range(tensor_a.ndim)]
    broadcast_shape[dim] = tensor_b.size(0)
    return tensor_a * tensor_b.reshape(broadcast_shape)

class WeightedMSELoss(nn.Module):
    """Mean squared error with per-index weights applied along one dimension.

    The weights are normalized once, at construction, so that all entries sum
    to 1. In ``forward`` the element-wise squared error is multiplied by the
    weights broadcast along ``dim`` and summed over ``dim``, which gives a
    weighted average across that axis. The remaining elements are then
    averaged uniformly.

    Args:
        weights: 1-D tensor, or a sequence accepted by ``torch.as_tensor``,
            with one weight per index of ``input.size(dim)``.
        dim: Dimension of ``input``/``target`` that the weights apply to.

    Raises:
        ValueError: If the weights sum to zero, since normalization would
            divide by zero and every loss would be NaN/inf.
    """

    def __init__(self, weights, dim=0):
        super().__init__()
        weights = torch.as_tensor(weights)
        total = weights.sum()
        if total == 0:
            raise ValueError("weights must not sum to zero")
        # Normalize over all entries (not along ``dim``) so the weights sum to 1.
        # Register as a buffer so ``module.to(device)`` / ``.cuda()`` move the
        # weights together with the module; a plain attribute would be left on
        # its original device. ``persistent=False`` keeps ``state_dict`` keys
        # identical to the previous plain-attribute version, so existing
        # checkpoints still load strictly.
        self.register_buffer("weights", weights / total, persistent=False)
        self.dim = dim

    def forward(self, input, target):
        """Return the weighted MSE between ``input`` and ``target`` as a scalar tensor.

        Args:
            input: Predictions; ``input.size(self.dim)`` must equal the number
                of weights.
            target: Ground truth, broadcast-compatible with ``input``.
        """
        squared_error = (input - target) ** 2

        # No-op when devices already match; protects against inputs that were
        # moved to a device the module itself was never moved to.
        weights = self.weights.to(squared_error.device)

        weighted_squared_error = multiply_along_dim(squared_error, weights, dim=self.dim)
        # Summing normalized weights over ``dim`` yields a weighted average
        # along that axis; the remaining axes are averaged uniformly.
        loss = weighted_squared_error.sum(dim=self.dim)

        return loss.mean()