import torch

from src.methods.models.layers.surrogates.lodl.lodl_surrogate import LODLSurrogate


class DirectedWeightedMSELODLSurrogate(LODLSurrogate):
    """Direction-aware weighted squared-error surrogate.

    Learns two weight tensors, each shaped like ``y_true``:

    * one applied where the prediction overshoots the target
      (``y_hat > y_true``);
    * one applied where it undershoots (``y_hat < y_true``).

    Both are floored at ``min_val`` when used. Elements that match the target
    exactly get weight zero, which is harmless because their squared error is
    zero as well.
    """

    def __init__(self, y_true: torch.Tensor, min_val=1e-3):
        """Create the surrogate and randomly initialise its learned weights.

        Args:
            y_true: Target tensor forwarded to ``LODLSurrogate``. This class
                reads it back as ``self._y_true``. Presumably the base class
                stores it there; confirm against ``LODLSurrogate``.
            min_val: Lower bound for the learned weights. It is added to the
                random initial values and used as the clamp floor in
                ``forward``.
        """
        super().__init__(y_true)

        self._min_val = min_val

        # Each weight starts uniformly in [min_val, min_val + 1).
        # The "over" weights are drawn before the "under" weights.
        over_init = min_val + torch.rand_like(self._y_true)
        self._W_pos = torch.nn.Parameter(over_init, requires_grad=True)
        under_init = min_val + torch.rand_like(self._y_true)
        self._W_neg = torch.nn.Parameter(under_init, requires_grad=True)

    def forward(self, y_hat: torch.Tensor) -> torch.Tensor:
        """Return the direction-weighted squared error, summed over dim 1.

        Args:
            y_hat: Predictions compared element-wise against ``y_true``.
                Presumably shaped ``(batch, *y_true.shape)`` with a 1-D
                ``y_true``. TODO confirm with callers.

        Returns:
            ``sum_j w_j * (y_hat_j - y_true_j) ** 2`` reduced over dimension 1.
            Here ``w_j`` is the clamped over-prediction weight where
            ``y_hat_j > y_true_j``, the clamped under-prediction weight where
            ``y_hat_j < y_true_j``, and 0 otherwise.
        """
        # Add a leading axis so the target broadcasts against a batch of
        # predictions.
        target = self._y_true.unsqueeze(0)

        # Floor the learned weights at min_val. Entries strictly below the
        # floor receive no gradient through the clamp.
        w_over = self._W_pos.clamp(min=self._min_val)
        w_under = self._W_neg.clamp(min=self._min_val)

        # Strict comparisons select the weight according to the direction of
        # the error.
        over_mask = (y_hat > target).float()
        under_mask = (y_hat < target).float()
        weights = over_mask * w_over + under_mask * w_under

        squared_error = (y_hat - target) ** 2
        return (weights * squared_error).sum(dim=1)
