import os

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# HuggingFace access token used to authenticate model/tokenizer downloads.
# Paste it here, or export it as the HF_TOKEN environment variable.
HF_TOKEN = os.environ.get('HF_TOKEN', '')
if not HF_TOKEN:
    # Explicit raise rather than ``assert`` so the check still runs under ``python -O``.
    raise RuntimeError('HuggingFace token is required to download models')

def load_model(model_name='llama-chat', size=7, device='cuda', quantize=False):
    """Load a supported causal-LM checkpoint and its tokenizer from the HuggingFace Hub.

    Args:
        model_name: Short alias of the model family. One of ``'llama-chat'``
            (``meta-llama/Llama-2-{size}b-chat-hf``), ``'llama'``
            (``meta-llama/Llama-2-{size}b-hf``), ``'llama-3'``
            (``meta-llama/Meta-Llama-3-8B``) or ``'mistral'``
            (``mistralai/Mistral-7B-v0.1``).
        size: Substituted into the Llama-2 repo ids (e.g. ``7`` ->
            ``Llama-2-7b-hf``). Ignored for ``'llama-3'`` and ``'mistral'``,
            whose repo ids are fixed.
        device: Passed to ``from_pretrained`` as ``device_map``. The model is
            only loaded when this is exactly ``'cuda'``; any other value
            returns the tokenizer alone.
        quantize: When True, load the model weights with a 4-bit
            ``BitsAndBytesConfig``.

    Returns:
        ``(model, tokenizer)`` when ``device == 'cuda'``, otherwise only the
        tokenizer.

    Raises:
        ValueError: If ``model_name`` is not one of the supported aliases.
    """
    # Alias -> Hub repo id template; only the Llama-2 templates use ``size``.
    repo_templates = {
        'llama-chat': 'meta-llama/Llama-2-{size}b-chat-hf',
        'llama-3': 'meta-llama/Meta-Llama-3-8B',
        'mistral': 'mistralai/Mistral-7B-v0.1',
        'llama': 'meta-llama/Llama-2-{size}b-hf',
    }
    try:
        template = repo_templates[model_name]
    except KeyError:
        raise ValueError("Model not supported") from None
    repo_id = template.format(size=size)

    # ``token`` supersedes the deprecated ``use_auth_token`` keyword.
    # Downloads are cached under ./models relative to the working directory.
    tokenizer = AutoTokenizer.from_pretrained(repo_id,
                                              token=HF_TOKEN,
                                              cache_dir='models')
    if device != 'cuda':
        # Tokenizer-only mode: the model weights are never loaded.
        return tokenizer

    # Only build the quantization config when a model is actually loaded.
    quantization_config = BitsAndBytesConfig(load_in_4bit=True) if quantize else None
    model = AutoModelForCausalLM.from_pretrained(repo_id,
                                                 quantization_config=quantization_config,
                                                 device_map=device,
                                                 token=HF_TOKEN,
                                                 cache_dir='models')
    return model, tokenizer

