import torch

from src.methods.models.layers.surrogates.egl.egl_surrogate import EGLSurrogate


class WeightedMSEEGLSurrogate(EGLSurrogate):
    """Surrogate loss: a weighted sum of squared errors against a fixed target.

    The weights are not learned inside this class. They are passed to
    ``forward`` as ``params`` and floored at ``min_val`` before use. The
    reduction is a *sum* over dimension 1, not a mean, despite the "MSE" in
    the name.

    NOTE(review): ``self._y_true`` and ``self._dim`` are not assigned in this
    class. They are presumably set by ``EGLSurrogate.__init__(y_true)`` —
    confirm against the base class.
    """

    def __init__(self, y_true: torch.Tensor, min_val=1e-5):
        """Store the target via the base class and remember the weight floor.

        Args:
            y_true: Target tensor, handed straight to ``EGLSurrogate``.
            min_val: Lower bound applied to the weights in ``forward``.
                Values in ``params`` below it are raised to it.
        """
        super().__init__(y_true)

        # Floor for the weights; prevents weights below this value
        # (including negative ones) from reaching the loss.
        self._min_val = min_val

    @property
    def params_dim(self) -> int:
        """Size of the parameter vector this surrogate expects.

        Returns ``self._dim``. That attribute is presumably computed by the
        base class from ``y_true`` — TODO confirm.
        """
        return self._dim

    def forward(self, y_hat: torch.Tensor, params: torch.Tensor) -> torch.Tensor:
        """Return the weighted squared error of ``y_hat`` vs. the stored target.

        Computes ``sum_dim1(clamp(params, min=min_val) * (y_hat - y_true) ** 2)``.

        Args:
            y_hat: Predictions. Presumably shaped ``(batch, dim)`` — TODO confirm.
                It must be at least 2-D, because the result is summed over
                dimension 1.
            params: Per-element weights. They must broadcast against
                ``y_hat - y_true``.

        Returns:
            A tensor with dimension 1 reduced away, e.g. shape ``(batch,)``
            for 2-D inputs.
        """
        weights = torch.clamp(params, min=self._min_val)
        squared_residual = (y_hat - self._y_true).pow(2)
        weighted = weights * squared_residual
        return weighted.sum(dim=1)
