import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Registry mapping each supported short model name to the checkpoint location
# that is handed to ``from_pretrained``. Entries left as ``None`` are
# placeholders: replace them with the real path / identifier before use.
_MODEL_PATHS = {
    # LLAMA VARIANTS
    "llama3-8b": None,  # FILL IN
    "Llama_31_8b": None,  # FILL IN
    # DEEPSEEK VARIANTS
    "Deepseek_Llama_8b": None,  # FILL IN
    "Deepseek_Qwen_14B": None,  # FILL IN
    # GEMMA-3 VARIANTS
    "gemma-3-4b-it": None,  # FILL IN
    "gemma-3-1b-it": None,  # FILL IN
    "gemma-3-12b-it": None,  # FILL IN
    # PHI VARIANTS
    "phi-3-mini": None,  # FILL IN
    "phi-3-medium": None,  # FILL IN
    "phi-4": None,  # FILL IN
    # QWEN 2.5 VARIANTS
    "qwen25_3B": None,  # FILL IN
    "qwen25_7B": None,  # FILL IN
    "qwen25_14B": None,  # FILL IN
    # QWEN 3 VARIANTS
    "qwen3_4B": None,  # FILL IN
    "qwen3_8B": None,  # FILL IN
    "qwen3_14B": None,  # FILL IN
    # MISTRAL VARIANTS
    "Mistral_7B": None,  # FILL IN
}


def get_model_and_tokenizer(
    model_name: str,
    *,
    dtype: torch.dtype = torch.bfloat16,
):
    """Load the causal language model and tokenizer registered as *model_name*.

    Args:
        model_name: One of the keys of ``_MODEL_PATHS`` (matched exactly,
            case-sensitive).
        dtype: Value forwarded to ``from_pretrained`` as ``torch_dtype``.
            Defaults to ``torch.bfloat16``.

    Returns:
        A ``(model, tokenizer)`` tuple, as returned by
        ``AutoModelForCausalLM.from_pretrained`` and
        ``AutoTokenizer.from_pretrained`` respectively.

    Raises:
        ValueError: If *model_name* is not a registered name, or if its entry
            in ``_MODEL_PATHS`` is still an unfilled placeholder.
    """
    # Validate up front so a typo yields a readable error instead of an
    # unbound-variable failure further down.
    if model_name not in _MODEL_PATHS:
        supported = ", ".join(sorted(_MODEL_PATHS))
        raise ValueError(
            f"Unknown model name {model_name!r}. Supported names: {supported}"
        )

    model_path = _MODEL_PATHS[model_name]
    if model_path is None:
        raise ValueError(
            f"No checkpoint path configured for model {model_name!r}; "
            "fill in its entry in _MODEL_PATHS."
        )

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=dtype,
        attn_implementation="eager",
        device_map="auto",
    )
    return model, tokenizer