import torch


class LancerSurrogateLayer(torch.nn.Module):
    """Learned surrogate that scores a (prediction, target) pair with a small MLP.

    The element-wise squared error between the prediction and the target is fed
    through a stack of ``Linear -> Tanh`` hidden layers. A final ``Linear``
    layer with a single output unit follows, optionally passed through a
    ``ReLU``.

    Args:
        input_dim: Number of input features of the first ``Linear`` layer. It
            must match the size of the last dimension of the squared-error
            tensor.
        hidden_units: Width of each hidden layer, in order. An empty sequence
            yields a single ``Linear(input_dim, 1)`` layer.
        use_final_relu: If True, append a ``ReLU`` after the output layer so
            the output is non-negative.
    """

    def __init__(self, input_dim: int, hidden_units: list[int], use_final_relu: bool = False):

        super().__init__()

        self._input_dim = input_dim
        self._output_dim = 1
        # Copy so that later mutation of the caller's list cannot make the
        # stored configuration disagree with the network actually built.
        self._hidden_units = list(hidden_units)
        self._use_final_relu = use_final_relu

        self._sequential_block = None

        self.initialize()

    def forward(self, x):
        """Score a prediction against a target.

        Args:
            x: Indexable pair; ``x[0]`` is the prediction and ``x[1]`` the
                target. NOTE(review): assumed to be tensors whose difference
                has last dimension ``input_dim`` (not checked here).

        Returns:
            The MLP applied to ``(x[0] - x[1]) ** 2``. ``Linear`` acts on the
            last dimension, so leading dimensions are preserved and the last
            dimension has size 1.
        """
        y_hat = x[0]
        y_true = x[1]

        error = torch.square(y_hat - y_true)

        return self._sequential_block(error)

    def initialize(self) -> None:
        """(Re)build the MLP from the stored configuration.

        Any existing block is replaced with freshly initialised layers.
        """
        layers = []
        prev_feat = self._input_dim
        for n_units in self._hidden_units:
            layers.append(torch.nn.Linear(prev_feat, n_units))
            layers.append(torch.nn.Tanh())
            prev_feat = n_units

        # Single-unit output layer; optional ReLU clamps the score at zero.
        layers.append(torch.nn.Linear(prev_feat, self._output_dim))
        if self._use_final_relu:
            layers.append(torch.nn.ReLU())

        self._sequential_block = torch.nn.Sequential(*layers)
