from models.commom import FeatureExtractor, CustomMobileNetV2

# Per-mode hyperparameters, selected by ``get_setup`` from ``configs.mode``
# ("parse" picks the "parse" entry; anything else picks "default"):
#   feat_dim  -> ``embedding_size`` passed to both encoder constructors
#   class_num -> returned by ``get_setup`` as ``output_dim``
#   split_dim -> returned by ``get_setup`` as ``classifier_dim``
settings = dict(
    default=dict(feat_dim=256, class_num=28, split_dim=256),
    parse=dict(feat_dim=384, class_num=28, split_dim=128),
)

def get_setup(configs):
    """Build the encoder factories and head dimensions for the configured mode.

    Args:
        configs: Configuration object; only its ``mode`` attribute is read.
            ``"parse"`` selects ``settings["parse"]``; any other value falls
            back to ``settings["default"]``.

    Returns:
        A tuple ``(encoder_fns, classifier_dim, output_dim)``:

        * ``encoder_fns`` -- list of two zero-argument callables. The first
          builds ``FeatureExtractor(input_channels=1, embedding_size=...)``.
          The second builds ``CustomMobileNetV2(embedding_size=...)``. Both
          use the selected ``feat_dim``.
        * ``classifier_dim`` -- the selected ``split_dim``.
        * ``output_dim`` -- the selected ``class_num``.
    """
    setup = settings["parse"] if configs.mode == "parse" else settings["default"]

    # Read all values now, so the factories below close over a plain int
    # instead of the shared, mutable ``settings`` entry. This keeps the
    # encoders' embedding size consistent with the dimensions returned here,
    # even if ``settings`` is modified before a factory is invoked.
    feat_dim = setup["feat_dim"]
    classifier_dim = setup["split_dim"]
    output_dim = setup["class_num"]

    # Factories (not instances) are returned, so callers decide when, and
    # how many times, each encoder is constructed.
    encoder_fns = [
        lambda: FeatureExtractor(input_channels=1, embedding_size=feat_dim),
        lambda: CustomMobileNetV2(embedding_size=feat_dim),
    ]
    return encoder_fns, classifier_dim, output_dim