import torch

from src.methods.models.layers.surrogates.lodl.lodl_surrogate import LODLSurrogate


class DirectedQuadraticLODLSurrogate(LODLSurrogate):
    """LODL surrogate whose quadratic form depends on the direction of each error.

    For every coordinate pair ``(i, j)`` four candidate weights are stored in
    ``self._L[i, j, :]``. On each call, one of them is selected according to
    whether the prediction over-shoots the target in coordinate ``i`` and in
    coordinate ``j``. The selected matrix is made lower-triangular and clamped
    to ``[-10, 10]``. The loss is the squared norm of
    ``(y_hat - y_true) @ tril(L)`` over the last axis.
    """

    def __init__(self, y_true: torch.Tensor):
        """Create the surrogate anchored at ``y_true``.

        Args:
            y_true: target passed through to ``LODLSurrogate``. The base class
                is assumed to set ``self._y_true`` and ``self._dim`` (both are
                read in this class); TODO confirm against ``LODLSurrogate``.
        """
        super().__init__(y_true)

        # One weight per (i, j, direction-combination); the last axis is
        # indexed by over_i + 2 * over_j (see _get_basis). Random init is
        # scaled down by 1 / dim**2.
        bases = torch.rand((self._dim, self._dim, 4)) / (self._dim ** 2)
        self._L = torch.nn.Parameter(bases, requires_grad=True)

    def forward(self, y_hat: torch.Tensor) -> torch.Tensor:
        """Evaluate the surrogate loss at ``y_hat``.

        Args:
            y_hat: prediction(s). The last axis is the decision dimension; any
                leading axes broadcast against ``self._y_true``.

        Returns:
            The loss with the last axis reduced: one value per broadcast batch
            element (a 0-d tensor when neither input has batch axes).
        """
        # Errors as a row vector, shape (..., 1, dim), so the matmul below
        # yields (..., 1, dim).
        diff = (y_hat - self._y_true).unsqueeze(-2)

        l_matrix = self._get_basis(y_hat)

        # NOTE(review): the clamp bounds the selected weights, not the product
        # diff @ L, so the loss is not bounded for large errors. Confirm this
        # placement is intended.
        output = torch.matmul(diff, torch.tril(l_matrix).clamp(-10, 10))
        # Sum of squares over the last axis, then drop the singleton row axis.
        output = torch.sum(output ** 2, dim=-1).squeeze(-1)

        return output

    def _get_basis(self, y_hat: torch.Tensor) -> torch.Tensor:
        """Select, for each coordinate pair, the weight matching the error directions.

        Args:
            y_hat: prediction(s), broadcastable against ``self._y_true``.

        Returns:
            Tensor of shape ``(*batch, dim, dim)``, where ``batch`` is the
            broadcast of the leading axes of ``y_hat`` and ``self._y_true``.
            Entry ``[..., i, j]`` equals ``self._L[i, j, over_i + 2 * over_j]``,
            with ``over_k = int(y_hat[..., k] > y_true[..., k])``.
        """
        # 1 where the prediction over-shoots the target, else 0 (ties -> 0).
        direction = (y_hat > self._y_true).type(torch.int64)

        direction_col = direction.unsqueeze(-1)  # (..., dim, 1): direction of i
        direction_row = direction.unsqueeze(-2)  # (..., 1, dim): direction of j
        # (..., dim, dim, 1): which of the 4 candidate weights each (i, j) uses.
        index = (direction_col + 2 * direction_row).unsqueeze(-1)

        # Expand over the broadcast batch shape (the one `index` actually has),
        # not just y_hat's leading shape. Otherwise gather fails when y_true
        # carries batch axes that y_hat lacks. expand creates a view, no copy.
        batch_shape = direction.shape[:-1]
        bases = self._L.expand(*batch_shape, *self._L.shape)
        basis = bases.gather(-1, index).squeeze(-1)

        return basis
