import logging

from llm_router.model.utils import print_trainable_parameters

logger = logging.getLogger(__name__)

# Maps each supported ``config.name`` to the ``(module path, class name)`` pair
# that implements it. The modules are imported lazily inside ``get_model`` so
# that only the module of the selected model is loaded.
_MODEL_REGISTRY = {
    "linear_mf_predictor": (
        "llm_router.model.linear_mf_predictor", "LinearMFPredictor"),
    "mlp_mf_predictor": (
        "llm_router.model.mlp_mf_predictor", "MLPMFPredictor"),
    "graph_predictor": (
        "llm_router.model.graph_predictor", "GraphPredictor"),
    # NOTE(review): "prefictor" looks like a typo for "predictor"; the path is
    # kept exactly as it was because the on-disk module name is not visible
    # from this file -- confirm and rename both together if it is a typo.
    "indep_cnp_predictor": (
        "llm_router.model.indep_cnp_prefictor", "IndepCNPPredictor"),
    "nested_cnp_predictor": (
        "llm_router.model.nested_cnp_predictor", "NestedCNPPredictor"),
    "linear_mf_classifier": (
        "llm_router.model.linear_mf_classifier", "LinearMFClassifier"),
    "mlp_mf_classifier": (
        "llm_router.model.mlp_mf_classifier", "MLPMFClassifier"),
    "graph_classifier": (
        "llm_router.model.graph_classifier", "GraphClassifier"),
    "indep_cnp_classifier": (
        "llm_router.model.indep_cnp_classifier", "IndepCNPClassifier"),
    "nested_cnp_classifier": (
        "llm_router.model.nested_cnp_classifier", "NestedCNPClassifier"),
}


def get_model(config):
    """Build the model selected by ``config.name``.

    The implementing module is imported on demand, the class is instantiated
    with ``config`` as its only argument, the resulting model is logged at
    INFO level and handed to ``print_trainable_parameters`` before being
    returned.

    Args:
        config: Configuration object exposing a ``name`` attribute that
            matches one of the keys of ``_MODEL_REGISTRY``. The same object
            is passed unchanged to the model constructor.

    Returns:
        The newly constructed model instance.

    Raises:
        NotImplementedError: If ``config.name`` is not a supported model
            name. The message lists the rejected name and the valid ones.
    """
    # Function-scope import, matching the lazy-import style of this factory.
    import importlib

    name = config.name
    entry = _MODEL_REGISTRY.get(name)
    if entry is None:
        supported = ", ".join(sorted(_MODEL_REGISTRY))
        raise NotImplementedError(
            f"Unknown model name {name!r}; supported names: {supported}"
        )

    module_path, class_name = entry
    model_cls = getattr(importlib.import_module(module_path), class_name)
    model = model_cls(config)

    # Lazy %-formatting yields the same message as the former f-string.
    logger.info("\n\n%s\n\n", model)
    print_trainable_parameters(model)

    return model