import torch

from src.methods.models.layers.surrogates.lodl.lodl_surrogate import LODLSurrogate


class WeightedMSELODLSurrogate(LODLSurrogate):
    """Learnable weighted-MSE surrogate loss anchored at a fixed target.

    Keeps one trainable weight per element of the stored target and scores
    a prediction as the weighted sum of squared errors against that target.

    Args:
        y_true: Reference target tensor, forwarded to ``LODLSurrogate``.
            The weights are drawn with ``torch.rand_like`` on the stored
            target, so it needs a floating-point dtype.
        min_val: Lower bound for the weights. It is used both as the offset
            of the random initialisation and as the clamp floor in
            ``forward``.
    """

    def __init__(self, y_true: torch.Tensor, min_val: float = 1e-3):
        super().__init__(y_true)

        self._min_val = min_val
        # One weight per target element, initialised uniformly in
        # [min_val, min_val + 1). Assumes the base class stores the target
        # as ``self._y_true`` (TODO confirm against LODLSurrogate).
        initial_weights = min_val + torch.rand_like(self._y_true)
        self._W = torch.nn.Parameter(initial_weights, requires_grad=True)

    def forward(self, y_hat: torch.Tensor) -> torch.Tensor:
        """Return the weighted squared error of ``y_hat``, reduced over dim 1.

        Args:
            y_hat: Prediction that is broadcast against the stored target.

        Returns:
            ``sum(clamp(W, min=min_val) * (y_hat - y_true) ** 2, dim=1)``.
            Assumes the inputs are at least 2-D with samples along dim 0.
            TODO confirm the expected shapes with the callers.
        """
        # Clamping enforces the floor at evaluation time. Weight entries
        # below the floor get zero gradient through clamp.
        weights = self._W.clamp(min=self._min_val)
        squared_error = (y_hat - self._y_true).pow(2)
        return (weights * squared_error).sum(dim=1)
