import torch
from torch import Tensor, nn


class StaticWeighting(nn.Module):
    """Combine several losses into one scalar using fixed, non-negative weights.

    The weights are stored as a registered buffer, not an ``nn.Parameter``.
    They are therefore included in ``state_dict()`` and follow ``.to()`` /
    ``.cuda()`` moves, but are not returned by ``parameters()`` and so are
    not updated by an optimizer.
    """

    def __init__(self, weights: Tensor):
        """Validate and store the weights.

        Args:
            weights: Per-loss weights. A ``Tensor`` is used as-is; any other
                tensor-like input (e.g. a list of floats) is converted with
                ``torch.as_tensor``.

        Raises:
            ValueError: If any weight is negative or NaN.
        """
        super().__init__()
        # as_tensor is a no-op for tensors, so existing callers are unaffected.
        weights = torch.as_tensor(weights)
        # `not all(w >= 0)` rather than `any(w < 0)`: NaN compares False to
        # everything, so this form also rejects NaN weights.
        if not torch.all(weights >= 0.0):
            raise ValueError("weights should all be >= 0 (and not NaN).")
        self.register_buffer("weights", weights)

    def forward(self, losses: Tensor) -> Tensor:
        """Return the weighted sum of ``losses`` as a scalar tensor.

        ``losses`` is multiplied elementwise by the weights (standard
        broadcasting applies), then summed over every element.
        Presumably ``losses`` has the same shape as ``weights``; confirm
        against callers.
        """
        return (self.weights * losses).sum()
