from transformers import AutoModelForCausalLM
import torch


def create_model(name, config=None, model_weights=None):
    """Load a supported causal-LM in bfloat16 and return it.

    Args:
        name: Model identifier. Supported values are
            ``"finetuned-Qwen/Qwen2.5-7B-Instruct"`` and
            ``"Qwen/Qwen2.5-7B-Instruct"``.
        config: Mapping read for ``"device_map"`` (always) and
            ``"hf_cache"`` (see below). Despite the ``None`` default, a
            mapping is required for both supported names.
        model_weights: Only used for the fine-tuned model. When given, it is
            passed to ``from_pretrained`` as the checkpoint to load. When
            ``None``, ``config["hf_cache"]`` is used as that checkpoint
            instead.

    Returns:
        The model returned by ``AutoModelForCausalLM.from_pretrained``.

    Raises:
        ValueError: If ``name`` is not a supported model, or if loading fails
            for any reason (including a missing/``None`` ``config``). The
            underlying error is chained as ``__cause__``.
    """
    if name == "finetuned-Qwen/Qwen2.5-7B-Instruct":
        try:
            # For the fine-tuned model, config["hf_cache"] is the checkpoint
            # location itself rather than a cache dir; explicit weights win.
            source = config["hf_cache"] if model_weights is None else model_weights
            return AutoModelForCausalLM.from_pretrained(
                source,
                torch_dtype=torch.bfloat16,
                device_map=config["device_map"],
            )
        except Exception as err:
            raise ValueError(f"Cannot load model name : {name}") from err

    if name == "Qwen/Qwen2.5-7B-Instruct":
        try:
            # NOTE: attn_implementation="flash_attention_2" was previously
            # considered here; it requires the flash-attn package.
            return AutoModelForCausalLM.from_pretrained(
                name,
                cache_dir=config["hf_cache"],
                torch_dtype=torch.bfloat16,
                device_map=config["device_map"],
            )
        except Exception as err:
            raise ValueError(f"Cannot load model name : {name}") from err

    # Previously an unknown name fell through to an unbound local variable.
    raise ValueError(f"Unknown model name : {name}")
