import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    LlamaTokenizerFast,
    PreTrainedTokenizer,
    PreTrainedTokenizerBase,
)


# Shared bitsandbytes setup: 4-bit weights using the "nf4" quant type with
# double quantization enabled, and bfloat16 as the compute dtype.
# _load_model applies it to the Llama-2 13b / 70b checkpoints.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

def load_model_and_tokenizer(model_name: str):
    """Load the causal LM and its tokenizer for ``model_name``.

    Returns:
        A ``(model, tokenizer)`` pair produced by ``_load_model`` and
        ``_load_tokenizer`` respectively.
    """
    # Tuple elements evaluate left to right, so the model is still loaded
    # before the tokenizer.
    return _load_model(model_name), _load_tokenizer(model_name)

def _load_model(model_name: str):
    """Load ``model_name`` as a causal LM and put it in eval mode.

    Names containing "Llama-2-13b" or "Llama-2-70b" are loaded with the
    module-level 4-bit ``bnb_config``. Any other model is loaded in
    bfloat16 with ``use_flash_attention_2=True``. Both paths pass
    ``device_map="auto"``.

    Returns:
        The loaded model, after ``model.eval()`` has been called on it.
    """
    large_llama = any(tag in model_name for tag in ("Llama-2-13b", "Llama-2-70b"))
    if large_llama:
        load_kwargs = dict(quantization_config=bnb_config, device_map="auto")
    else:
        # NOTE(review): recent transformers releases may warn that
        # ``use_flash_attention_2`` is deprecated in favour of
        # ``attn_implementation="flash_attention_2"`` — confirm against the
        # pinned transformers version before switching.
        load_kwargs = dict(
            torch_dtype=torch.bfloat16,
            use_flash_attention_2=True,
            device_map="auto",
        )
    model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
    model.eval()
    return model


def _load_tokenizer(model_name: str) -> PreTrainedTokenizerBase:
    """Load the tokenizer for ``model_name`` with left padding.

    Names containing "Llama-2" are loaded through ``LlamaTokenizerFast``.
    Any other name goes through ``AutoTokenizer``. In both cases
    ``padding_side="left"`` is set and the EOS token is reused as the pad
    token.

    The returned tokenizer also gets a ``period_token_id`` attribute. It is
    the id of the last token produced when encoding ``"Game over."`` without
    special tokens (presumably the "." token — confirm for new tokenizers).

    Returns:
        The configured tokenizer. Fast tokenizers are not subclasses of
        ``PreTrainedTokenizer``, so the annotation uses their common base,
        ``PreTrainedTokenizerBase``.
    """
    # Tokenizers have no dtype; the former ``torch_dtype`` kwarg was a no-op
    # that only ended up in the tokenizer's init kwargs.
    if "Llama-2" in model_name:
        tokenizer = LlamaTokenizerFast.from_pretrained(model_name, padding_side="left")
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")

    # Unconditionally reuse EOS as padding, overriding any pad token the
    # checkpoint ships with (same for every model family).
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.period_token_id = tokenizer("Game over.", add_special_tokens=False).input_ids[-1]

    return tokenizer