import torch

from src.methods.models.layers.surrogates.egl.egl_surrogate import EGLSurrogate


class DirectedQuadraticEGLSurrogate(EGLSurrogate):
    """EGL surrogate that evaluates a direction-dependent quadratic form of the error.

    For an error row vector ``diff = y_hat - y_true`` the surrogate returns
    ``||diff @ L||^2`` (i.e. ``diff L L^T diff^T``). ``L`` is a lower-triangular
    ``dim x dim`` matrix whose entries are picked from four parameter slots
    according to which coordinates of ``y_hat`` exceed ``y_true``
    (see ``_get_basis``). Being a sum of squares, the output is non-negative.
    """

    def __init__(self, y_true: torch.Tensor):
        """Store the target via ``EGLSurrogate.__init__``.

        NOTE(review): ``self._dim`` and ``self._y_true`` are read below but set
        outside this class -- presumably by the base class from ``y_true``;
        confirm.
        """
        super().__init__(y_true)

    @property
    def params_dim(self) -> int:
        """Number of parameters: a ``(dim, dim, 4)`` grid, one slot per sign pattern."""
        return self._dim * self._dim * 4

    def forward(self, y_hat: torch.Tensor, params: torch.Tensor) -> torch.Tensor:
        """Evaluate the surrogate loss.

        Args:
            y_hat: Predictions whose last dimension is ``dim``; leading
                dimensions must broadcast against ``y_true``.
            params: Tensor with exactly ``params_dim`` elements. It is viewed as
                a ``(dim, dim, 4)`` grid shared by every leading (batch) position.

        Returns:
            Loss tensor whose shape is the broadcast leading shape of ``y_hat``
            and ``y_true`` (the feature dimension is reduced away).
        """
        param_grid = params.view(self._dim, self._dim, 4)

        # Treat the error as a row vector, shape (..., 1, dim), so the matmul
        # below computes diff @ L.
        diff = (y_hat - self._y_true).unsqueeze(-2)

        l_matrix = self._get_basis(y_hat, param_grid)

        # Keep only the lower triangle and bound entries to [-10, 10]
        # (presumably to keep the surrogate numerically tame).
        lower = torch.tril(l_matrix).clamp(-10, 10)

        projected = torch.matmul(diff, lower)  # (..., 1, dim)
        return torch.sum(projected ** 2, dim=-1).squeeze(-1)

    def _get_basis(self, y_hat: torch.Tensor, l: torch.Tensor) -> torch.Tensor:
        """Assemble the per-sample ``dim x dim`` matrix from the four parameter slots.

        Entry ``(i, j)`` is ``l[i, j, d_i + 2 * d_j]``, where ``d_k`` is 1 if
        ``y_hat[..., k] > y_true[..., k]`` and 0 otherwise.

        Args:
            y_hat: Predictions whose last dimension is ``dim``.
            l: Parameter grid of shape ``(dim, dim, 4)``.

        Returns:
            Tensor of shape ``(*batch, dim, dim)``, where ``batch`` is the
            broadcast of the leading dimensions of ``y_hat`` and ``y_true``.
        """
        direction = (y_hat > self._y_true).long()  # (*batch, dim)

        # index[..., i, j, 0] = d_i + 2 * d_j, always in {0, 1, 2, 3}.
        index = (direction.unsqueeze(-1) + 2 * direction.unsqueeze(-2)).unsqueeze(-1)

        # Expand to the broadcast batch shape carried by ``direction``, not just
        # ``y_hat``'s. gather needs the index and input to agree on every
        # non-gathered dimension, which otherwise breaks whenever ``y_true``
        # adds or broadcasts batch dimensions relative to ``y_hat``.
        bases = l.expand(*direction.shape[:-1], *l.shape)
        return bases.gather(-1, index).squeeze(-1)
