from torch.nn import Linear, LayerNorm, ReLU, Sequential

def get_mlp(
    in_size,
    out_size,
    n_hidden,
    hidden_size=None,
    activation=ReLU,
    activate_last=True,
    layer_norm=True,
):
    """Build a multi-layer perceptron as a ``Sequential`` module.

    The network is ``n_hidden`` repetitions of ``Linear -> activation``
    followed by a final ``Linear`` projecting to ``out_size``.

    Args:
        in_size: Number of input features of the first ``Linear`` layer.
        out_size: Number of output features of the final ``Linear`` layer.
        n_hidden: Number of hidden ``Linear -> activation`` pairs placed
            before the output layer. Zero (or a negative value, since the
            layers are produced with ``range``) yields no hidden layers.
        hidden_size: Width of every hidden layer. Must be given whenever
            ``n_hidden`` is positive; it is unused otherwise.
        activation: Zero-argument callable returning a fresh activation
            module; it is called once per place an activation is inserted.
        activate_last: If true, append an activation after the output layer.
        layer_norm: If true, append ``LayerNorm(out_size)`` at the very end.
            This only takes effect when ``activate_last`` is also true.

    Returns:
        A ``Sequential`` holding the layers in order.

    Raises:
        TypeError: If ``n_hidden`` is positive and ``hidden_size`` is None.
    """
    # Fail fast with a readable message; otherwise Linear raises an opaque
    # TypeError about invalid tensor-size arguments.
    if n_hidden > 0 and hidden_size is None:
        raise TypeError(
            "hidden_size must be specified when n_hidden > 0 "
            f"(got n_hidden={n_hidden}, hidden_size=None)"
        )

    arch = []
    l_in = in_size
    for _ in range(n_hidden):
        arch.append(Linear(l_in, hidden_size))
        arch.append(activation())
        # Every layer after the first consumes the previous hidden output.
        l_in = hidden_size

    arch.append(Linear(l_in, out_size))

    if activate_last:
        arch.append(activation())

        # Normalization is applied only on top of an activated output.
        if layer_norm:
            arch.append(LayerNorm(out_size))

    return Sequential(*arch)