from .mlp import *
from .convnet import *
from .resnet import *
from .lstm import *
from .hf_model import *

def get_model(model_config, peft_config=None):
    """Build a model from a config mapping.

    ``model_config["type"]`` selects the model class. Every other key is
    forwarded unchanged as a keyword argument to that class's constructor.

    Args:
        model_config: Mapping with a ``"type"`` key, one of ``"mlp"``,
            ``"convnet2"``, ``"convnet2_resize"``, ``"convnet3"``,
            ``"resnet18"``, ``"lstm"``, ``"llama"`` or ``"llama-chat"``,
            plus constructor keyword arguments. It is not modified.
        peft_config: Passed positionally to ``HFModel`` for the ``"llama"``
            and ``"llama-chat"`` types; ignored for every other type.

    Returns:
        The constructed model instance.

    Raises:
        KeyError: if ``model_config`` has no ``"type"`` key.
        NotImplementedError: if ``"type"`` names an unsupported model.
    """
    # Work on a shallow copy: popping "type" from the caller's mapping would
    # make the same config unusable for a second call.
    config = dict(model_config)
    model_type = config.pop("type")

    # Lambdas defer the constructor lookup until the chosen entry is called,
    # so only the requested model class is ever touched.
    builders = {
        "mlp": lambda: MLP(**config),
        "convnet2": lambda: ConvNet2(**config),
        "convnet2_resize": lambda: ConvNet2Resize(**config),
        "convnet3": lambda: ConvNet3(**config),
        "resnet18": lambda: ResNet18(**config),
        "lstm": lambda: LSTM(**config),
        # Hugging Face checkpoints; only these receive ``peft_config``.
        "llama": lambda: HFModel(
            "meta-llama/Llama-2-7b-hf", peft_config, **config
        ),
        "llama-chat": lambda: HFModel(
            "meta-llama/Llama-2-7b-chat-hf", peft_config, **config
        ),
    }

    try:
        build = builders[model_type]
    except KeyError:
        raise NotImplementedError(
            f"Unsupported model type {model_type!r}; "
            f"expected one of {sorted(builders)}"
        ) from None
    return build()