
from .resnet import ResNet18AutoEncoder
from .mlp import MLP_AutoEncoder


def build_networks(net_name, in_channels=3, mid_dim=128, mid_size=4, af_name=None):
    """Build the autoencoder network named by ``net_name``.

    Args:
        net_name: Which architecture to build; one of ``'resnet18'`` or ``'mlp'``.
        in_channels: Forwarded as ``in_channels`` to ``ResNet18AutoEncoder``, or
            as ``input_dim`` to ``MLP_AutoEncoder``.
        mid_dim: Forwarded as ``mid_dim`` to ``ResNet18AutoEncoder``, or as
            ``num_hidden_nodes`` to ``MLP_AutoEncoder``.
        mid_size: Forwarded to ``ResNet18AutoEncoder`` only; ignored for ``'mlp'``.
        af_name: Forwarded to ``MLP_AutoEncoder`` only; ignored for ``'resnet18'``.

    Returns:
        The newly constructed autoencoder instance.

    Raises:
        ValueError: If ``net_name`` is not one of the implemented networks.
    """
    implemented_networks = ('resnet18', 'mlp')
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let an unknown name fall through and make
    # this function silently return None.
    if net_name not in implemented_networks:
        raise ValueError(
            "Unknown network {!r}; expected one of {}".format(net_name, implemented_networks)
        )

    if net_name == 'resnet18':
        return ResNet18AutoEncoder(in_channels=in_channels, mid_dim=mid_dim, mid_size=mid_size)

    # Only 'mlp' remains after the validation above.
    return MLP_AutoEncoder(input_dim=in_channels, num_hidden_nodes=mid_dim, af_name=af_name)
