import torch

from src.methods.models.layers.surrogates.lodl.lodl_surrogate import LODLSurrogate


class QuadraticLODLSurrogate(LODLSurrogate):
    """LODL surrogate loss built from a learned quadratic form.

    For a prediction ``y_hat`` the surrogate is

        sum(((y_hat - y_true) @ clamp(tril(L))) ** 2)   # over the last axis

    where ``L`` is a learnable ``(dim, dim)`` matrix of which only the lower
    triangle is used.
    """

    # Bound applied to the entries of tril(L) in ``forward``. Entries outside
    # [-_CLAMP, _CLAMP] are clipped and receive zero gradient through clamp.
    _CLAMP = 100.0

    def __init__(self, y_true: torch.Tensor):
        """Create the surrogate for the target ``y_true``.

        Args:
            y_true: target tensor, forwarded to ``LODLSurrogate.__init__``.
                NOTE(review): the base class (not visible here) is assumed to
                set ``self._dim`` and ``self._y_true`` -- confirm.
        """
        super().__init__(y_true)

        # Small positive random initialisation, scaled by 1 / dim**2.
        basis = torch.rand((self._dim, self._dim)) / (self._dim ** 2)
        self._L = torch.nn.Parameter(basis, requires_grad=True)

    def forward(self, y_hat: torch.Tensor) -> torch.Tensor:
        """Evaluate the surrogate loss for the prediction ``y_hat``.

        Args:
            y_hat: predictions whose last axis has size ``self._dim``; must
                broadcast against ``self._y_true``.

        Returns:
            The squared norm of ``(y_hat - y_true) @ clamp(tril(L))`` taken over
            the last axis: one value per leading index (a 0-d tensor for 1-D
            input, shape ``(batch,)`` for 2-D input).
        """
        diff = y_hat - self._y_true

        weights = torch.tril(self._L).clamp(-self._CLAMP, self._CLAMP)
        if torch.is_floating_point(diff):
            # matmul requires both operands to share a dtype; follow the
            # data's precision (no-op when they already match, and gradients
            # still reach L through the cast).
            weights = weights.to(diff.dtype)

        # matmul contracts the last axis of diff, so the squared norm is
        # reduced over that same last axis (equivalent to dim=1 for 2-D input).
        output = torch.matmul(diff, weights)
        return torch.sum(output ** 2, dim=-1)
