from .network import MNIST_RIM, MNIST_LeNet, HAR_RIM, ADULT_RIM, BANK_RIM, CREDIT_RIM, MNIST_MLP, MNIST_MLP_Autoencoder, DCC_LeNet, DCC_LeNet_Autoencoder, MNIST_LeNet_Autoencoder

def build_network(net_name):
    """Build the neural network registered under ``net_name``.

    Args:
        net_name: Name of the network. One of ``'credit_rim'``,
            ``'adult_rim'``, ``'bank_rim'``, ``'har_rim'``, ``'mnist_LeNet'``,
            ``'mnist_rim'``, ``'mnist_mlp'`` or ``'DCC_LeNet'`` (case-sensitive).

    Returns:
        A new instance of the matching network class from ``.network``,
        constructed with no arguments.

    Raises:
        ValueError: If ``net_name`` is not one of the supported names.
    """

    implemented_networks = ('credit_rim', 'adult_rim', 'bank_rim', 'har_rim',
                            'mnist_LeNet', 'mnist_rim', 'mnist_mlp', 'DCC_LeNet')

    # Explicit check rather than ``assert``: assertions are removed under
    # ``python -O``, which previously let an unknown name fall through and
    # silently return None.
    if net_name not in implemented_networks:
        raise ValueError(
            "Unknown network {!r}; expected one of: {}".format(
                net_name, ', '.join(implemented_networks)))

    # The classes are looked up only after validation. Every supported name
    # maps to exactly one constructor.
    constructors = {
        'credit_rim': CREDIT_RIM,
        'bank_rim': BANK_RIM,
        'adult_rim': ADULT_RIM,
        'har_rim': HAR_RIM,
        'mnist_rim': MNIST_RIM,
        'mnist_mlp': MNIST_MLP,
        'mnist_LeNet': MNIST_LeNet,
        'DCC_LeNet': DCC_LeNet,
    }
    return constructors[net_name]()


def build_autoencoder(net_name):
    """Build the autoencoder network that corresponds to ``net_name``.

    Args:
        net_name: Name of the base network. One of ``'mnist_LeNet'``,
            ``'mnist_mlp'`` or ``'DCC_LeNet'`` (case-sensitive).

    Returns:
        A new instance of the matching autoencoder class from ``.network``,
        constructed with no arguments.

    Raises:
        ValueError: If ``net_name`` has no autoencoder implementation.
    """

    implemented_networks = ('mnist_LeNet', 'mnist_mlp', 'DCC_LeNet')

    # Explicit check rather than ``assert``: assertions are removed under
    # ``python -O``, which previously let an unknown name fall through and
    # silently return None.
    if net_name not in implemented_networks:
        raise ValueError(
            "No autoencoder for network {!r}; expected one of: {}".format(
                net_name, ', '.join(implemented_networks)))

    # The classes are looked up only after validation. Every supported name
    # maps to exactly one autoencoder constructor.
    constructors = {
        'mnist_mlp': MNIST_MLP_Autoencoder,
        'mnist_LeNet': MNIST_LeNet_Autoencoder,
        'DCC_LeNet': DCC_LeNet_Autoencoder,
    }
    return constructors[net_name]()
