import torch
from transformers import AutoModelForCausalLM

from src.monkeypatch import replace_llama
# Checkpoint identifier handed to `from_pretrained`.
# NOTE(review): there is no organisation prefix, so this presumably names a
# local directory rather than a Hub repo id — confirm.
_checkpoint = "Meta-Llama-3-8B-Instruct"

# Loading options, gathered in one place so they can be reviewed at a glance.
_load_options = {
    "torch_dtype": torch.float16,  # request half-precision weights
    "low_cpu_mem_usage": True,
    "device_map": "auto",  # leave device placement to the library
    "use_cache": True,
    # NOTE(review): presumably needs the optional flash-attn package — confirm.
    "attn_implementation": "flash_attention_2",
}

# Instantiate the causal LM; `model` stays a module-level name.
model = AutoModelForCausalLM.from_pretrained(_checkpoint, **_load_options)

# Apply the project's Llama monkeypatch from `src.monkeypatch`.
# NOTE(review): this runs *after* the model has been built. If the patch swaps
# whole classes rather than methods on existing classes, it would have to run
# before loading to take effect — confirm against `replace_llama`.
replace_llama()