from cache import dLLMCache, dLLMCacheConfig
from hooks import register_cache_LLaDA
from dataclasses import asdict
from transformers import AutoModel, AutoTokenizer
import torch
from utils import generate

# --- Run settings -----------------------------------------------------------

# Where the model and the input tensor are placed.
device: str = "cuda"

# Sampling settings forwarded to `generate` and echoed in the startup banner.
steps: int = 256
gen_length: int = 256

# Cache settings: `use_cache` gates cache registration; the remaining values
# are passed to `dLLMCacheConfig`.
use_cache: bool = True
prompt_interval_steps: int = 100
gen_interval_steps: int = 7
transfer_ratio: float = 0.25

# Load the LLaDA instruct checkpoint in bfloat16, then move it to `device`
# and switch it to eval mode.
model = AutoModel.from_pretrained(
    "GSAI-ML/LLaDA-8B-Instruct",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
model = model.to(device).eval()

# The tokenizer is loaded from the same checkpoint name.
tokenizer = AutoTokenizer.from_pretrained(
    "GSAI-ML/LLaDA-8B-Instruct",
    trust_remote_code=True,
)

# Startup banner summarising the run settings.
banner_rule = "*" * 66
print(banner_rule)
print(
    f"**  Answer Length: {gen_length}  |  Sampling Steps: {steps}  |  Use_cache {use_cache}"
)
print(banner_rule)

# When caching is enabled, build the cache configuration from the settings
# above, create the dLLMCache instance from it, and register the cache on the
# model using the "model.transformer.blocks" module path.
if use_cache:
    cache_config = dLLMCacheConfig(
        prompt_interval_steps=prompt_interval_steps,
        gen_interval_steps=gen_interval_steps,
        transfer_ratio=transfer_ratio,
    )
    cache_kwargs = asdict(cache_config)
    dLLMCache.new_instance(**cache_kwargs)
    register_cache_LLaDA(model, "model.transformer.blocks")

# Read a single question from stdin and wrap it as a one-turn chat.
user_input = input("Enter your question: ")
messages = [{"role": "user", "content": user_input}]

# Render the chat template to text (with the generation prompt appended) and
# tokenise it once; both the ids and the mask come from the same encoding.
prompt_text = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=False
)
encoded = tokenizer(prompt_text)
# NOTE(review): the mask is forwarded as the tokenizer's plain (unbatched)
# output, exactly as before — confirm whether `generate` expects a tensor.
attention_mask = encoded["attention_mask"]
# Add a leading batch dimension of size 1 and place the ids on `device`.
input_ids = torch.tensor(encoded["input_ids"]).to(device).unsqueeze(0)

# Reset the feature cache for this prompt length before generating.
# NOTE(review): this runs even when `use_cache` is False — confirm that
# `dLLMCache()` is valid without a prior `new_instance` call.
feature_cache = dLLMCache()
feature_cache.reset_cache(input_ids.shape[1])

# Use the configured `gen_length` rather than reusing `steps`, so the answer
# length shown in the banner is the one actually requested. `block_length` is
# set to the full generation length (a single block), which is what the
# previous `steps` value amounted to while steps == gen_length.
generation_ids = generate(
    input_ids=input_ids,
    attention_mask=attention_mask,
    model=model,
    steps=steps,
    gen_length=gen_length,
    block_length=gen_length,
)

answer = tokenizer.batch_decode(generation_ids, skip_special_tokens=True)[0]
print(f"LLaDA's reply: {answer}")
print("-----------------------------------------------------------------------")
