import torch


class FeedForwardLayer(torch.nn.Module):
    """Fully connected feed-forward block built from ``torch.nn.Linear`` layers.

    The network is a ``Linear -> ReLU`` pair for every entry of
    ``hidden_units``, followed by a final ``Linear`` projecting to
    ``output_dim`` and, optionally, a trailing ``ReLU``.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_units: list[int],
        use_final_relu: bool = False,
    ):
        """Store the configuration and build the layers.

        Args:
            input_dim: Number of input features of the first linear layer.
            output_dim: Number of output features of the last linear layer.
            hidden_units: Width of each hidden layer, in order. May be empty,
                in which case the block is a single linear layer.
            use_final_relu: If true, a ReLU is appended after the last
                linear layer.
        """
        super().__init__()

        self._input_dim = input_dim
        self._output_dim = output_dim
        # Keep a private snapshot of the widths: the caller may mutate its
        # list afterwards (or pass a one-shot iterable), and a later call to
        # ``initialize()`` must rebuild the same architecture.
        self._hidden_units = list(hidden_units)
        self._use_final_relu = use_final_relu

        self._sequential_block = None

        self.initialize()

    def forward(self, x):
        """Pass ``x`` through the sequential block and return the result.

        ``x`` is handed directly to the first ``torch.nn.Linear``, so its
        feature dimension has to match ``input_dim``.
        """
        return self._sequential_block(x)

    def initialize(self) -> None:
        """(Re)build the sequential block with freshly constructed layers.

        Calling this again discards the current layers, and with them their
        parameters, and replaces them with new ones of the same layout.
        """
        # Collect the layers first so the module attribute is replaced only
        # once, with a fully built block.
        layers: list[torch.nn.Module] = []

        in_features = self._input_dim
        for n_units in self._hidden_units:
            layers.append(torch.nn.Linear(in_features, n_units))
            layers.append(torch.nn.ReLU())
            in_features = n_units

        layers.append(torch.nn.Linear(in_features, self._output_dim))
        if self._use_final_relu:
            layers.append(torch.nn.ReLU())

        self._sequential_block = torch.nn.Sequential(*layers)
