from .vbert_eurobert import *
from .vllama import *

def get_model_classes(model_type: str):
    """Return the trio of classes registered for ``model_type``.

    Args:
        model_type: Model family key. Only ``'vbert'`` and ``'vllama'`` are
            recognised; matching is an exact (case-sensitive) equality test.

    Returns:
        A 3-tuple ``(config_class, base_model_class, lm_head_model_class)``:
        ``(VBertConfig, VBertModel, VBertForMaskedLM)`` for ``'vbert'`` and
        ``(VLlamaConfig, VLlamaModel, VLlamaForCausalLM)`` for ``'vllama'``.
        All six names come from the module's star imports.

    Raises:
        ValueError: If ``model_type`` is not a recognised key.
    """
    # Guard clauses: each supported family returns immediately; anything that
    # falls through is unsupported.
    if model_type == 'vbert':
        return VBertConfig, VBertModel, VBertForMaskedLM
    if model_type == 'vllama':
        return VLlamaConfig, VLlamaModel, VLlamaForCausalLM
    raise ValueError(f"Unknown model type: {model_type}")

def get_model_auto_map(model_type: str):
    """Return the auto-map constant associated with ``model_type``.

    Args:
        model_type: Model family key. Only ``'vbert'`` and ``'vllama'`` are
            recognised; matching is an exact (case-sensitive) equality test.

    Returns:
        ``VBERT_AUTO_MAP`` for ``'vbert'`` or ``VLLAMA_AUTO_MAP`` for
        ``'vllama'``. Both constants come from the module's star imports and
        are returned as-is (not copied). NOTE(review): presumably these are
        Hugging Face ``auto_map`` config entries; confirm against their
        definitions.

    Raises:
        ValueError: If ``model_type`` is not a recognised key.
    """
    # Same dispatch shape as get_model_classes: early return per family,
    # then fail for anything unrecognised.
    if model_type == 'vbert':
        return VBERT_AUTO_MAP
    if model_type == 'vllama':
        return VLLAMA_AUTO_MAP
    raise ValueError(f"Unknown model type: {model_type}")