from models.classifier import Classifier, ClassifierBN

# Module-level caches of already-constructed models, keyed by the
# ``weights_file`` argument passed to the getter functions below.
# ``_pretrained_classifiers`` holds ``Classifier`` instances and is filled by
# ``get_pretrained_classifier``.
_pretrained_classifiers = {}
# ``_pretrained_classifiers_BN`` holds ``ClassifierBN`` instances and is
# filled by ``get_pretrained_classifier_bn``.
_pretrained_classifiers_BN = {}


def get_pretrained_classifier(num_classes, weights_file):
    """Return the cached ``Classifier`` for ``weights_file``, creating it if needed.

    The module-level cache is keyed by ``weights_file`` alone, so
    ``num_classes`` is only used the first time a given ``weights_file`` is
    requested; later calls with the same ``weights_file`` return the instance
    that was built then, whatever ``num_classes`` they pass.

    Args:
        num_classes: Forwarded to ``Classifier`` when a new instance is built.
        weights_file: Forwarded to ``Classifier`` and used as the cache key
            (so it must be hashable).

    Returns:
        The ``Classifier`` instance stored in the cache for ``weights_file``.
    """
    # Cache hit: hand back the instance built on an earlier call.
    if weights_file in _pretrained_classifiers:
        return _pretrained_classifiers[weights_file]

    # Cache miss: build once, remember it, and return the same object.
    model = Classifier(num_classes, weights_file)
    _pretrained_classifiers[weights_file] = model
    return model


def get_pretrained_classifier_bn(num_classes, weights_file):
    """Return the cached ``ClassifierBN`` for ``weights_file``, creating it if needed.

    The module-level cache is keyed by ``weights_file`` alone, so
    ``num_classes`` is only used the first time a given ``weights_file`` is
    requested; later calls with the same ``weights_file`` return the instance
    that was built then, whatever ``num_classes`` they pass.

    Args:
        num_classes: Forwarded to ``ClassifierBN`` when a new instance is
            built.
        weights_file: Forwarded to ``ClassifierBN`` and used as the cache key
            (so it must be hashable).

    Returns:
        The ``ClassifierBN`` instance stored in the cache for ``weights_file``.
    """
    # Cache hit: hand back the instance built on an earlier call.
    if weights_file in _pretrained_classifiers_BN:
        return _pretrained_classifiers_BN[weights_file]

    # Cache miss: build once, remember it, and return the same object.
    model = ClassifierBN(num_classes, weights_file)
    _pretrained_classifiers_BN[weights_file] = model
    return model
