import torch.nn as nn

# Activation modules selectable by name (see build_mlp). Each value is a
# single pre-constructed module instance, created once at import time.
nonlinearity_dict = dict(
    relu=nn.ReLU(),
    tanh=nn.Tanh(),
)

def build_mlp(layer_sizes, nonlinearity='relu', dropout=None, layernorm=False):
    """Build a multilayer perceptron as an ``nn.Sequential``.

    Args:
        layer_sizes: Iterable of ``(width, count)`` pairs. Each pair is
            expanded into ``count`` consecutive copies of ``width``; the
            flattened list of widths gives the feature size at every layer
            boundary, so N widths produce N - 1 ``nn.Linear`` layers.
        nonlinearity: Key into ``nonlinearity_dict`` selecting the activation
            placed after every Linear layer except the last.
        dropout: If not None, the drop probability of an ``nn.Dropout``
            inserted after each activation.
        layernorm: If True, an ``nn.LayerNorm`` over the input width is placed
            before every Linear layer.

    Returns:
        An ``nn.Sequential`` of the layers described above. The final Linear
        layer has no activation or dropout after it. Fewer than two widths
        yields an empty ``nn.Sequential``.

    Raises:
        ValueError: If ``nonlinearity`` is not a key of ``nonlinearity_dict``.
    """
    import copy

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if nonlinearity not in nonlinearity_dict:
        raise ValueError(
            f"Unknown nonlinearity {nonlinearity!r}; "
            f"expected one of {sorted(nonlinearity_dict)}"
        )
    activation_template = nonlinearity_dict[nonlinearity]

    # Expand (width, count) pairs into the flat list of layer-boundary widths.
    widths = []
    for block in layer_sizes:
        widths.extend([block[0]] * block[1])

    n_linear = len(widths) - 1
    layers = []
    for idx, (d_in, d_out) in enumerate(zip(widths, widths[1:])):
        if layernorm:
            layers.append(nn.LayerNorm(d_in))
        layers.append(nn.Linear(d_in, d_out))
        if idx < n_linear - 1:  # no activation/dropout after the output layer
            # Fresh copy per position: reusing the single instance stored in
            # nonlinearity_dict would share hooks/mode (and any parameters)
            # across every position and every model built here.
            layers.append(copy.deepcopy(activation_template))
            if dropout is not None:
                layers.append(nn.Dropout(dropout))

    return nn.Sequential(*layers)