import flax.linen as nn

class FeedForward(nn.Module):
    """MLP head: two (Dense -> SiLU -> LayerNorm) stages, then a Dense projection.

    Attributes:
      config: mapping read for the keys ``'final_hidden_1'`` and
        ``'final_hidden_2'``, which give the feature widths of the first and
        second hidden Dense layers. A missing key raises ``KeyError`` when the
        module is called.
      out_features: width of the final Dense projection. Defaults to 1, which
        was previously hard-coded.
      layer_norm_epsilon: epsilon passed to both LayerNorm layers. Defaults to
        1e-6, which was previously hard-coded.
    """

    config: dict
    out_features: int = 1
    layer_norm_epsilon: float = 1e-6

    @nn.compact
    def __call__(self, x):
        """Apply the head to ``x`` and return the final Dense output.

        Args:
          x: input array; the Dense layers act on its last axis.

        Returns:
          Array whose last axis has size ``out_features``.
        """
        # Flax auto-names inline submodules by type and creation order
        # (Dense_0, LayerNorm_0, Dense_1, LayerNorm_1, Dense_2). The loop keeps
        # that order, so the parameter tree stays compatible with existing
        # checkpoints.
        for width_key in ('final_hidden_1', 'final_hidden_2'):
            x = nn.Dense(self.config[width_key])(x)
            x = nn.silu(x)
            x = nn.LayerNorm(epsilon=self.layer_norm_epsilon)(x)
        return nn.Dense(self.out_features)(x)
