import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

def get_hf_model_name(args):
    """Translate a (model_name, model_size) selection into a Hugging Face Hub repo id.

    Args:
        args: Any object exposing ``model_name`` and ``model_size`` attributes
            (presumably the parsed command-line namespace; verify against caller).

    Returns:
        The Hub repository id string of the matching checkpoint.

    Raises:
        ValueError: If the combination of ``model_name`` and ``model_size`` is
            not one of the supported entries below.
    """
    # Each entry: (family name, size label, Hugging Face Hub repo id).
    supported = (
        ("llama3.2", "1B", "meta-llama/Llama-3.2-1B-Instruct"),
        ("llama3.2", "3B", "meta-llama/Llama-3.2-3B-Instruct"),
        ("llama3.1", "8B", "meta-llama/Llama-3.1-8B-Instruct"),
        ("qwen2.5", "3B", "Qwen/Qwen2.5-3B-Instruct"),
        ("qwen2.5", "7B", "Qwen/Qwen2.5-7B-Instruct"),
        ("qwen3", "1.7B", "Qwen/Qwen3-1.7B"),
        ("qwen3", "4B", "Qwen/Qwen3-4B"),
        ("opt", "1.3B", "facebook/opt-1.3b"),
        ("opt", "2.7B", "facebook/opt-2.7b"),
        ("opt", "6.7B", "facebook/opt-6.7b"),
    )
    for family, size, repo_id in supported:
        # Plain == comparisons, family first, mirroring an if/elif ladder.
        if args.model_name == family and args.model_size == size:
            return repo_id
    raise ValueError(f"Unsupported model_name ({args.model_name}) or model_size ({args.model_size}).")


def load_model(args):
    """Load the selected causal LM and its tokenizer, and put the model in eval mode.

    Args:
        args: Object exposing ``model_name``/``model_size`` (consumed by
            ``get_hf_model_name``), ``model_cache_dir`` (passed to
            ``from_pretrained`` as ``cache_dir``) and ``disable_gpu``.

    Returns:
        A ``(model, tokenizer, device)`` tuple where ``device`` is ``"cuda"``
        or ``"cpu"`` and the model has been moved to it.

    Raises:
        ValueError: Propagated from ``get_hf_model_name`` for unsupported
            model/size combinations.
    """
    hf_model_name = get_hf_model_name(args)

    model = AutoModelForCausalLM.from_pretrained(hf_model_name, cache_dir=args.model_cache_dir)
    tokenizer = AutoTokenizer.from_pretrained(hf_model_name, cache_dir=args.model_cache_dir)

    if args.disable_gpu:
        print("USE CPU!")
        device = "cpu"
    elif torch.cuda.is_available():
        print("USE GPU!")
        device = "cuda"
    else:
        # GPU was requested but CUDA is missing: report the real fallback
        # instead of claiming the GPU is in use.
        print("CUDA not available, falling back to CPU!")
        device = "cpu"

    # Move on every path so the returned device always matches the model's
    # placement (a no-op for a freshly loaded CPU model).
    model = model.to(device)
    model.eval()

    return model, tokenizer, device