"""
An example of LLM generation using the kv-cache utilities.
Requires 'transformers==4.47.0'.
"""

import torch

from monkey_patch.utils import enable_optimal_brain_kv, enable_optimal_brain_kv_flashattn2
from cache_utils import *
from utils import load_kv_cache, load_model_and_tokenizer, seed_everything


# Model to run. A Qwen2 checkpoint also works, e.g.:
# model_name = "Qwen2/Qwen2-0.5B-Instruct"
model_name = "meta-llama/Llama-3.1-8B-Instruct"
model, tokenizer = load_model_and_tokenizer(
    model_name_or_path=model_name,
    precision="bf16",
    flash_attn=False,
)
model.eval()
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print("model and tokenizer loaded")

# Enable OBCache on the model via the project's monkey_patch utilities; the
# architecture string passed in is "llama" or "qwen2", derived from the name.
# NOTE(review): `enable_optimal_brain_kv_flashattn2` is imported at the top of
# the file and is presumably the counterpart to use with flash_attn=True —
# confirm its signature before switching.
enable_optimal_brain_kv(model, 'llama' if 'llama' in model_name.lower() else 'qwen2')
print("OBCache enabled")

# Needle-in-a-haystack style prompt: a "magic number" is hidden inside filler
# text and the model is asked to recall it. The "\n" sequences below are real
# newline characters (not an escaped backslash followed by "n").
prompt = (
    "A special magic number is hidden within the following text. "
    "Make sure to memorize it. I will quiz you about the number afterwards.\n"
    "July 2010What hard liquor, cigarettes, heroin, and crack have in common is "
    "that they're all more concentrated forms of less addictive predecessors. "
    "Most if not all the things we describe as addictive are. "
    "And the scary thing is, the process that created them is accelerating."
    "We wouldn't want to stop it. It's the same process that cures diseases: "
    "technological progress. "
    "One of the special magic numbers for worried-colt is: 8930103. "
    "Technological progress means making things do more of what we want. "
    "When the thing we want is something we want to want, "
    "we consider technological progress good. "
    "If some new technique makes solar cells x% more efficient, "
    "that seems strictly better. \n"
    "What is the special magic number for worried-colt mentioned in the provided text?"
)
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": prompt},
]
# Render the conversation with the model's own chat template (needed for any
# instruct model, Llama and Qwen2 alike) and append the assistant header so
# generation starts with the answer.
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

# KV cache with eviction. Other available methods:
#   "obc_value_p2", "obc_key_p2", "obc_value_key_p2"
# Budgets may alternatively be given as ratios, e.g.
#   recent_ratio=0.01, heavy_ratio=0.19
# NOTE(review): num_recent / num_heavy are presumably token budgets for the
# recent window and the heavy-hitter set, and decode_evict presumably enables
# eviction during decoding as well — confirm against load_kv_cache.
past_key_values = load_kv_cache(
    method="h2o",
    num_recent=16,
    num_heavy=240,
    decode_evict=True,
)
print("Cache Eviction Method:\n", past_key_values)

seed_everything(42)

# The rendered chat template already contains any special tokens (e.g. the
# Llama BOS), so don't let the tokenizer add them a second time. Move the whole
# encoding -- input_ids AND attention_mask -- to the model's device so
# generate() never mixes CPU and GPU tensors.
model_inputs = tokenizer(
    [prompt], return_tensors="pt", add_special_tokens=False
).to(model.device)
generated_ids = model.generate(
    model_inputs.input_ids,
    attention_mask=model_inputs.attention_mask,
    max_new_tokens=30,
    past_key_values=past_key_values,
    do_sample=False,  # greedy decoding
    temperature=1.0,
)
# generate() returns prompt + continuation; keep only the newly generated tokens.
generated_ids = [
    output_ids[len(input_ids):]
    for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print("generated response: \n", response)