import torch

from src.methods.models.layers.surrogates.egl.egl_surrogate import EGLSurrogate


class DirectedWeightedMSEEGLSurrogate(EGLSurrogate):
    """Squared-error surrogate with separate weights for over- and under-prediction.

    For each output coordinate the squared residual ``(y_hat - y_true) ** 2`` is
    scaled by a ``w_pos`` weight when the prediction overshoots the target
    (``y_hat > y_true``) and by a ``w_neg`` weight when it undershoots
    (``y_hat < y_true``); exact matches get weight 0. Both weight sets are
    clamped from below at ``min_val``. The weighted errors are summed over the
    last (feature) axis.

    NOTE(review): ``self._y_true`` and ``self._dim`` are not assigned in this
    class; they are presumably set by ``EGLSurrogate.__init__`` from ``y_true``
    (``_dim`` being the size of the feature axis) -- confirm against the base.
    """

    def __init__(self, y_true: torch.Tensor, min_val: float = 1e-5):
        """
        Args:
            y_true: target values, forwarded unchanged to ``EGLSurrogate``.
            min_val: lower bound applied to both weight sets in ``forward``.
        """
        super().__init__(y_true)

        self._min_val = min_val

    @property
    def params_dim(self) -> int:
        """Size of the ``params`` feature axis: ``dim`` over-prediction weights
        followed by ``dim`` under-prediction weights."""
        return self._dim * 2

    def forward(self, y_hat: torch.Tensor, params: torch.Tensor) -> torch.Tensor:
        """Return the direction-weighted squared error of ``y_hat`` against the target.

        Args:
            y_hat: predictions; must broadcast against ``self._y_true``
                (presumably shape ``(batch, dim)`` -- TODO confirm with callers).
            params: weights whose last axis has size ``2 * dim``; the first
                ``dim`` entries weight over-predictions, the last ``dim``
                entries weight under-predictions.

        Returns:
            The weighted squared error summed over the last axis (shape
            ``(batch,)`` for 2-D input).

        ``torch.split`` raises if the last axis of ``params`` is not ``2 * dim``.
        """
        w_pos, w_neg = torch.split(params, [self._dim, self._dim], dim=-1)

        # Derive the direction from the residual itself so the direction masks
        # and the squared error share exactly the same broadcast shape.
        residual = y_hat - self._y_true

        # Select a weight per coordinate rather than multiplying by 0/1 masks:
        # a non-finite value in the *unused* weight would otherwise turn into
        # NaN via 0 * inf. Exact ties keep weight 0, as a pure mask would give.
        weights = torch.where(
            residual > 0,
            w_pos.clamp(min=self._min_val),
            torch.where(
                residual < 0,
                w_neg.clamp(min=self._min_val),
                w_neg.new_zeros(()),
            ),
        )

        # Reduce over the feature axis -- the same axis ``params`` is split on.
        return torch.sum(weights * residual ** 2, dim=-1)
