import flax.linen as nn
from typing import Sequence


# Hidden-layer nonlinearities selectable by name through ``MLP.activation``.
_ACTIVATIONS = {
    "relu": nn.relu,
    "tanh": nn.tanh,
    "gelu": nn.gelu,
    "sigmoid": nn.sigmoid,
    "silu": nn.silu,
    "elu": nn.elu,
}


class MLP(nn.Module):
    """Multi-layer perceptron built from a stack of ``nn.Dense`` layers.

    Attributes:
        features: Output width of each Dense layer, in order. The last entry
            is the width of the network's output.
        activation: Name of the nonlinearity applied after every layer except
            the last one (the output layer is left linear). Must be one of the
            keys of ``_ACTIVATIONS``.
    """

    features: Sequence[int]
    activation: str = "relu"

    @nn.compact
    def __call__(self, inputs):
        """Run ``inputs`` through the Dense stack.

        Args:
            inputs: Array fed to the first ``nn.Dense`` layer (Dense contracts
                over the last axis).

        Returns:
            The output of the final Dense layer, with no activation applied.
            If ``features`` is empty, ``inputs`` is returned unchanged.

        Raises:
            ValueError: If ``activation`` is not a supported activation name.
                Previously an unknown name silently disabled every
                nonlinearity, turning the network into a linear map.
        """
        try:
            act = _ACTIVATIONS[self.activation]
        except KeyError:
            raise ValueError(
                f"Unsupported activation {self.activation!r}; "
                f"expected one of {sorted(_ACTIVATIONS)}"
            ) from None

        x = inputs
        last = len(self.features) - 1
        for i, feat in enumerate(self.features):
            x = nn.Dense(feat)(x)
            # Keep the final (output) layer linear; only hidden layers get
            # the nonlinearity.
            if i != last:
                x = act(x)
        return x
