from src.models.model_gpt_j import ProbingGPTJ


def get_model(model_name, low_resource_mode, device):
    """Instantiate the probing model wrapper for ``model_name``.

    Args:
        model_name: Model identifier. Only ``"EleutherAI/gpt-j-6b"`` is
            currently supported.
        low_resource_mode: Passed through unchanged to the model
            constructor (its meaning is defined by ``ProbingGPTJ``).
        device: Passed through unchanged to the model constructor.

    Returns:
        A ``ProbingGPTJ`` instance when ``model_name`` is supported.

    Raises:
        NotImplementedError: If ``model_name`` is not supported. The
            message names the rejected model and the supported ones.
    """
    supported_models = ("EleutherAI/gpt-j-6b",)
    if model_name == "EleutherAI/gpt-j-6b":
        return ProbingGPTJ(low_resource_mode, device)
    # Keep NotImplementedError (callers may catch it), but say which name
    # was rejected so misconfigurations are easy to diagnose.
    raise NotImplementedError(
        f"Unsupported model {model_name!r}; "
        f"supported models: {', '.join(supported_models)}"
    )
