from __future__ import print_function
from models.BayesianModels.Bayesian3Conv3FC import BBB3Conv3FC
from models.BayesianModels.BayesianAlexNet import BBBAlexNet
from models.BayesianModels.BayesianLeNet import BBBLeNet
from models.BayesianModels.Bayesian_1FC import BBB1Fc
from models.BayesianModels.ResNet18 import ResNetClassifier

from models.BayesianModels.Bayesian_2FC import BBB2Fc


def getModel(net_type, inputs, outputs, dim, priors, layer_type, activation_type, neurons=None):
    """Build and return the model selected by ``net_type``.

    Matching is exact and case-sensitive against the lowercase keys
    ``'1fc'``, ``'2fc'``, ``'lenet'``, ``'alexnet'``, ``'3conv3fc'`` and
    ``'resnet'``.

    Args:
        net_type: Key selecting the architecture (see above).
        inputs: Forwarded to the model constructor as its *second*
            positional argument. This is presumably the input
            size/channel count; confirm against the model classes.
        outputs: Forwarded as the *first* positional argument. This is
            presumably the number of outputs/classes; confirm.
        dim: Forwarded only for ``'1fc'`` and ``'2fc'``.
        priors: Forwarded to every model except ``'resnet'``.
        layer_type: Forwarded to every model except ``'resnet'``.
        activation_type: Forwarded to every model except ``'resnet'``.
        neurons: Forwarded only for ``'1fc'``. Defaults to ``None``.

    Returns:
        A freshly constructed instance of the selected model class.

    Raises:
        ValueError: If ``net_type`` is not one of the supported keys.
    """
    # NOTE: every constructor takes ``outputs`` before ``inputs``, which is
    # the reverse of this function's parameter order.
    if net_type == '1fc':
        return BBB1Fc(outputs, inputs, neurons, dim, priors, layer_type, activation_type)
    elif net_type == '2fc':
        return BBB2Fc(outputs, inputs, dim, priors, layer_type, activation_type)
    elif net_type == 'lenet':
        return BBBLeNet(outputs, inputs, priors, layer_type, activation_type)
    elif net_type == 'alexnet':
        return BBBAlexNet(outputs, inputs, priors, layer_type, activation_type)
    elif net_type == '3conv3fc':
        return BBB3Conv3FC(outputs, inputs, priors, layer_type, activation_type)
    elif net_type == 'resnet':
        # ResNetClassifier takes only the output/input sizes; the Bayesian
        # configuration arguments are not used for this architecture.
        return ResNetClassifier(outputs, inputs)
    # List the exact (lowercase) keys accepted above and echo the rejected
    # value. str.format is used instead of f-strings to stay compatible
    # with the Python 2 support implied by the __future__ import.
    raise ValueError(
        'Network should be one of [1fc / 2fc / lenet / alexnet / 3conv3fc / resnet], '
        'got {!r}'.format(net_type))
