from abc import ABC, abstractmethod

import torch


class EGLSurrogate(ABC, torch.nn.Module):
    """Abstract ``torch.nn.Module`` built around a fixed target tensor ``y_true``.

    Concrete subclasses must implement :attr:`params_dim` and :meth:`forward`.
    Instantiating this class directly raises ``TypeError`` (abstract members).

    ``y_true`` is registered as a *non-persistent* buffer. It follows the
    module through ``.to()`` / ``.cuda()`` / ``.double()`` like any other module
    tensor, but it is not written to ``state_dict()``. This keeps checkpoints
    identical to the earlier plain-attribute version.
    """

    def __init__(self, y_true: torch.Tensor):
        """Store the target tensor and record the size of its first dimension.

        Args:
            y_true: Target tensor. It must have at least one dimension because
                ``y_true.shape[0]`` is read; a 0-d tensor raises ``IndexError``.
        """
        super().__init__()

        # A plain tensor attribute is invisible to nn.Module, so .to()/.cuda()
        # would not move it and forward() would hit device/dtype mismatches.
        # persistent=False keeps it out of state_dict (unchanged checkpoint keys).
        self.register_buffer("_y_true", y_true, persistent=False)
        # Number of entries along the first axis of the target.
        self._dim = y_true.shape[0]

    @property
    @abstractmethod
    def params_dim(self) -> int:
        """Number of parameters this surrogate expects.

        NOTE(review): presumably this is the size of the ``params`` tensor
        passed to :meth:`forward`; confirm against concrete subclasses.
        """

    @abstractmethod
    def forward(self, y_hat: torch.Tensor, params: torch.Tensor) -> torch.Tensor:
        """Evaluate the surrogate for prediction ``y_hat`` and ``params``.

        Subclasses compare ``y_hat`` against the stored ``self._y_true``. The
        shapes expected for ``y_hat`` and ``params`` are subclass-defined
        (TODO confirm).
        """
