import torch

from src.methods.models.layers.surrogates.egl.egl_surrogate import EGLSurrogate


class QuadraticEGLSurrogate(EGLSurrogate):
    """Quadratic surrogate parameterised by a lower-triangular matrix.

    For every row ``d`` of ``diff = y_hat - y_true`` the surrogate returns

        ||d @ L||^2  ==  d @ (L @ L.T) @ d.T

    where ``L`` is the lower triangle of ``params`` (reshaped to
    ``(dim, dim)``) with each entry clamped to ``[-_CLAMP, _CLAMP]``.
    Since the result is a sum of squares it is always non-negative.
    """

    # Bound applied to every entry of L before it is used; presumably keeps
    # the quadratic form numerically bounded for extreme params values.
    _CLAMP = 100

    def __init__(self, y_true: torch.Tensor):
        """Store the target tensor via the base class.

        NOTE(review): ``self._dim`` and ``self._y_true`` used below are not
        set here; presumably ``EGLSurrogate.__init__`` derives them from
        ``y_true`` -- confirm against the base class.
        """
        super().__init__(y_true)

    @property
    def params_dim(self) -> int:
        """Number of elements ``forward`` expects in ``params`` (``dim * dim``).

        Only the ``dim * (dim + 1) / 2`` lower-triangular entries influence
        the output; the strictly upper entries are zeroed by ``torch.tril``
        and therefore receive zero gradient.
        """
        return self._dim * self._dim

    def forward(self, y_hat: torch.Tensor, params: torch.Tensor) -> torch.Tensor:
        """Evaluate the quadratic surrogate for ``y_hat``.

        Args:
            y_hat: Predictions; must broadcast against ``self._y_true``.
                Assumed to be ``(batch, dim)`` because the reduction is over
                axis 1 -- TODO confirm with callers.
            params: Tensor with exactly ``params_dim`` elements, interpreted
                as a ``(dim, dim)`` matrix in row-major order. It may be
                non-contiguous.

        Returns:
            ``sum((diff @ L) ** 2, dim=1)``, which for a 2-D ``diff`` of
            shape ``(batch, dim)`` has shape ``(batch,)``.

        Raises:
            RuntimeError: if ``params`` does not have ``dim * dim`` elements.
        """
        # reshape rather than view: identical result when params is
        # contiguous, but it does not fail when params is a non-contiguous
        # slice/transpose of a larger tensor.
        lower = torch.tril(params.reshape(self._dim, self._dim))
        lower = lower.clamp(-self._CLAMP, self._CLAMP)

        diff = y_hat - self._y_true

        # Row-wise ||d @ L||^2, i.e. the quadratic form d (L L^T) d^T.
        projected = torch.matmul(diff, lower)
        return torch.sum(projected ** 2, dim=1)
