import copy

import torch
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

from src.models.llama import get_llama_model
from src.models.mistral import get_mistral_model

def load_model_and_tokenizer(model_name_or_path, quant_config, forward_quant=False, device=None):
    """Load a causal language model and its tokenizer.

    When ``quant_config["name"]`` is ``"Identity"`` the model is loaded
    directly through ``AutoModelForCausalLM.from_pretrained`` in half
    precision. Otherwise the model is built by ``get_llama_model`` or
    ``get_mistral_model``, selected by a case-insensitive substring match on
    ``model_name_or_path``.

    If ``quant_config["kwargs"]`` contains a ``"quantizer_path"`` entry, it is
    treated as a ``str.format`` template and its ``{model_name}`` placeholder
    is filled with ``model_name_or_path`` (``"/"`` replaced by ``"_"``). The
    resolved path is placed in a copy of the config; the caller's
    ``quant_config`` is never modified, so one config can be reused for
    several models.

    Args:
        model_name_or_path: Model identifier or path handed to the loaders.
        quant_config: Mapping with a ``"name"`` entry and an optional
            ``"kwargs"`` entry (a mapping or ``None``).
        forward_quant: Forwarded unchanged to the Llama/Mistral builders.
        device: Used as ``device_map`` in the Identity branch and forwarded
            to the builders otherwise; ``None`` is replaced by ``"auto"``.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        NotImplementedError: If the config is not ``"Identity"`` and the
            model name contains neither ``"llama"`` nor ``"mistral"``.
    """
    device = "auto" if device is None else device

    if quant_config["name"] == "Identity":
        print("identity")
        model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            device_map=device,
            torch_dtype=torch.half,
        )
    else:
        # ``.get`` tolerates configs that omit "kwargs" entirely; the original
        # code already treated ``None`` as "no kwargs".
        quant_kwargs = quant_config.get("kwargs")
        if quant_kwargs is not None and "quantizer_path" in quant_kwargs:
            resolved_path = quant_kwargs["quantizer_path"].format(
                model_name=model_name_or_path.replace("/", "_")
            )
            # Resolve the template on shallow copies: writing the formatted
            # path back into the caller's dict would consume the
            # "{model_name}" placeholder and break reuse of the same config
            # for a different model.
            quant_kwargs = copy.copy(quant_kwargs)
            quant_kwargs["quantizer_path"] = resolved_path
            quant_config = copy.copy(quant_config)
            quant_config["kwargs"] = quant_kwargs

        lowered_name = model_name_or_path.lower()
        if "llama" in lowered_name:
            model = get_llama_model(model_name_or_path, quant_config, forward_quant, device)
        elif "mistral" in lowered_name:
            model = get_mistral_model(model_name_or_path, quant_config, forward_quant, device)
        else:
            raise NotImplementedError(f"Model {model_name_or_path} is not implemented")

    # NOTE(review): trust_remote_code=True lets the tokenizer loader run
    # custom code shipped with the checkpoint; only use trusted sources.
    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path, trust_remote_code=True
    )
    return model, tokenizer
