from cache import dLLMCache, dLLMCacheConfig
from hooks import register_cache_Dream
from dataclasses import asdict
from transformers import AutoModel, AutoTokenizer
import torch

# ---- dLLM-Cache settings (forwarded verbatim to dLLMCacheConfig) ----------
# Presumably these control how often prompt / generated-token features are
# recomputed and what fraction is refreshed; confirm in cache.dLLMCacheConfig.
prompt_interval_steps: int = 100
gen_interval_steps: int = 7
transfer_ratio: float = 0.25
# When False, no cache instance is created and no hooks are registered.
use_cache: bool = True

# ---- Runtime / generation settings -----------------------------------------
device: str = "cuda"  # target device for model weights and input tensors
max_new_tokens: int = 256  # answer length passed to diffusion_generate
steps: int = 256  # number of diffusion sampling steps

# Hub id shared by the model weights and the tokenizer.
model_path = "Dream-org/Dream-v0-Instruct-7B"

# Load the (remote-code) Dream model in bfloat16, move it to `device`, and
# put it in eval mode for generation.
model = AutoModel.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
model = model.to(device).eval()

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Framed one-line summary of the run configuration.
banner_rule = "*" * 66
print(banner_rule)
print(
    f"**  Answer Length: {max_new_tokens}  |  Sampling Steps: {steps}  |  Use_cache {use_cache}"
)
print(banner_rule)
if use_cache:
    # Build the cache configuration, then create the shared cache instance
    # from its fields and hook it into the model's decoder layers.
    cache_config = dLLMCacheConfig(
        prompt_interval_steps=prompt_interval_steps,
        gen_interval_steps=gen_interval_steps,
        transfer_ratio=transfer_ratio,
    )
    dLLMCache.new_instance(**asdict(cache_config))
    register_cache_Dream(model, "model.layers")

user_input = input("Enter your question: ")

# Wrap the raw question in the model's chat template, ending with the
# assistant generation prompt, and keep it as a plain string.
messages = [{"role": "user", "content": user_input}]
prompt = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=False
)

# Tokenize once and get batched (1, prompt_len) tensors directly, instead of
# running the tokenizer twice and converting/unsqueezing lists by hand.
encoded = tokenizer(prompt, return_tensors="pt")
input_ids = encoded["input_ids"].to(device)
attention_mask = encoded["attention_mask"].to(device)
prompt_len = input_ids.shape[1]

# A cache instance is only created (dLLMCache.new_instance) and hooked into
# the model when use_cache is True. Touching dLLMCache() otherwise is wasted
# work, and may not be supported without new_instance (confirm in cache.py).
if use_cache:
    feature_cache = dLLMCache()
    feature_cache.reset_cache(prompt_len)

output = model.diffusion_generate(
    input_ids,
    attention_mask=attention_mask,
    max_new_tokens=max_new_tokens,
    output_history=False,
    return_dict_in_generate=True,
    steps=steps,
    temperature=0.2,
    top_p=0.95,
)
# The returned sequences include the prompt; keep only the generated tail.
generation_ids = output.sequences[:, prompt_len:]

answer = tokenizer.batch_decode(generation_ids, skip_special_tokens=True)[0]
print(f"Dream's reply: {answer}")
print("-----------------------------------------------------------------------")
