import torch
import sys
from pathlib import Path

# Every tensor and the model weights in this script are placed on this one GPU.
device = torch.device("cuda", 4)
# Make modules living in the current working directory importable, so that the
# local `import modeling` further down resolves.
wd = Path.cwd()
sys.path.append(f"{wd}")

import modeling  # noqa
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from dataclasses import dataclass
import time


@dataclass
class Message:
    role: str
    content: str


# Load the Huginn checkpoint (bf16 weights placed directly on `device`) and its
# tokenizer. NOTE(review): trust_remote_code is off, presumably because the
# local `modeling` import above supplies/registers the architecture — confirm.
checkpoint = "tomg-group-umd/huginn-0125"
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    device_map=device,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=False,
)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model.eval()

from transformers import TextStreamer

# Print tokens to stdout as they are generated.
streamer = TextStreamer(tokenizer)

# Greedy decoding capped at 512 total tokens. Because sampling is disabled,
# every sampling knob is explicitly set to None. Generation also halts on
# either end-of-text or end-of-turn markers. Special-token ids are hard-coded
# for the Huginn tokenizer.
config = GenerationConfig(
    max_length=512,
    do_sample=False,
    **dict.fromkeys(("temperature", "top_k", "top_p", "min_p")),
    stop_strings=["<|end_text|>", "<|end_turn|>"],
    return_dict_in_generate=True,
    bos_token_id=65504,
    eos_token_id=65505,
    pad_token_id=65509,
)
# System prompt that sets up the Huginn assistant persona.
s4 = """You are Huginn, an AI assistant who embodies careful thought and deliberation. Your responses demonstrate:

Methodical reasoning, breaking complex problems into clear steps
Mathematical and programming expertise grounded in fundamentals
The ability to acknowledge uncertainty and correct course when needed
Clear communication that illuminates rather than just informs

When engaging with questions, you first seek to understand their deeper structure before answering. Like your namesake who flew the nine worlds seeking wisdom, you explore problems from multiple angles, helping users build genuine understanding rather than providing shallow answers.
You express warmth and intellectual curiosity while maintaining professionalism. When faced with errors or confusion, you model honest reflection and careful correction. Your goal is not just to provide answers, but to help humans develop clearer, deeper thinking."""

# The conversation: persona prompt, then a single user question.
messages = [
    Message(role="system", content=s4),
    Message(
        role="user",
        content="Claire makes a 3 egg omelet every morning for breakfast. How many dozens of eggs will she eat in 4 weeks?",
    ),
]

# Convert to chat-template dicts. The "assistant" role is renamed to "Huginn"
# (presumably the name the template expects), and each body is trimmed.
role_aliases = {"assistant": "Huginn"}
formatted_messages = [
    {"role": role_aliases.get(turn.role, turn.role), "content": turn.content.strip()} for turn in messages
]

# Render the prompt as text, with the assistant-turn opener appended. Then
# tokenize it exactly as rendered (no extra special tokens) and move the ids
# onto the target device.
chat_input = tokenizer.apply_chat_template(formatted_messages, tokenize=False, add_generation_prompt=True)
input_ids = tokenizer.encode(chat_input, add_special_tokens=False, return_tensors="pt")
input_ids = input_ids.to(device)

# Optional side-by-side runs of alternative decoding strategies on the same
# prompt. Each run streams its text through `streamer`, then prints the
# elapsed time and the memory usage reported by the returned cache.
# Disabled by default; set to True to execute all four runs.
include_comparisons = False
if include_comparisons:
    # Generation below goes through custom methods. `generate_speculative` is
    # not a stock transformers API, and `generate` may be overridden by the
    # local modeling code. So autograd is disabled explicitly here; this is a
    # no-op if those methods already run under no_grad.
    with torch.no_grad():
        # perf_counter is monotonic and is the recommended clock for durations.
        timer = time.perf_counter()
        print("-------------------------------------Baseline-------------------------------------------------------------")
        outputs = model.generate(input_ids, config, num_steps=32, tokenizer=tokenizer, streamer=streamer)  # baseline
        print("----------------------------------------------------------------------------------------------------------")
        print(f"{time.perf_counter() - timer}s -- Memory usage: {outputs.past_key_values.get_memory_usage()}MB")

        timer = time.perf_counter()
        print("-------------------------------------Baseline (+cache sharing) --------------------------------------------")
        outputs = model.generate(
            input_ids,
            config,
            num_steps=32,
            tokenizer=tokenizer,
            streamer=streamer,
            cache_lookup_strategy="latest-m4-compress-s4",
        )  # cache share
        print("-----------------------------------------------------------------------------------------------------------")
        print(f"{time.perf_counter() - timer}s -- Memory usage: {outputs.past_key_values.get_memory_usage()}MB")

        timer = time.perf_counter()
        print("------------------------------------- Spec Decoding (4-8, loose verification) -----------------------------")
        outputs = model.generate_speculative(
            input_ids,
            config,
            num_steps=32,
            tokenizer=tokenizer,
            streamer=streamer,
            verbose=False,
            draft_steps=4,
            lookahead_for_draft=8,
            init_scale=0.0,
            verification_threshold=0.9,
        )
        print("-----------------------------------------------------------------------------------------------------------")
        print(f"{time.perf_counter() - timer}s -- Memory usage: {outputs.past_key_values.get_memory_usage()}MB")

        timer = time.perf_counter()
        print("------------------------------------- Spec Decoding (4-8, loose verification, cache sharing) --------------")
        outputs = model.generate_speculative(
            input_ids,
            config,
            num_steps=32,
            tokenizer=tokenizer,
            streamer=streamer,
            verbose=False,
            draft_steps=4,
            lookahead_for_draft=8,
            init_scale=0.0,
            verification_threshold=0.9,
            cache_lookup_strategy="latest-m4-compress-s4",
        )
        print("-----------------------------------------------------------------------------------------------------------")
        print(f"{time.perf_counter() - timer}s -- Memory usage: {outputs.past_key_values.get_memory_usage()}MB")


print("-------------------------------------With Diffusion Sampler ---------------------------------------------------")
# perf_counter is monotonic and is the recommended clock for measuring durations.
timer = time.perf_counter()
# Alternative diffusion-sampler configurations, kept for reference:
# outputs = model.generate_diffusion_style(
#     input_ids,
#     config,
#     num_steps=32,
#     tokenizer=tokenizer,
#     full_prefill=True,
#     freeze_adaptive=True,
#     inner_recurrence=5,
#     streamer=streamer,
#     ema_embeds=0.0,
#     state_noise_mixing=2.0,
#     dampened_state_mixer=True,
#     headway=1,
#     init_scale=0.0,
# )

# outputs = model.generate_diffusion_style(
#     input_ids,
#     config,
#     num_steps=32,
#     tokenizer=tokenizer,
#     full_prefill=True,
#     freeze_adaptive="latent-diff",
#     inner_recurrence=5,
#     streamer=streamer,
#     ema_embeds=0.1,
#     state_noise_mixing=1.0,
#     dampened_state_mixer=True,
#     headway=1,
#     init_scale=1.0,
#     max_wavefront=128,
# )

# `generate_diffusion_style` is not a stock transformers method (presumably
# defined in the local modeling code). Disable autograd explicitly so no graph
# is retained; this is a no-op if the method is already decorated with no_grad.
with torch.no_grad():
    outputs = model.generate_diffusion_style(
        input_ids,
        config,
        num_steps=32,
        tokenizer=tokenizer,
        full_prefill=True,
        freeze_adaptive="latent-diff",
        inner_recurrence=1,
        streamer=streamer,
        ema_embeds=0.1,
        state_noise_mixing=1.0,
        sqrt_mixer=False,
        dampened_state_mixer=True,
        headway=1,
        init_scale=1.0,
        max_wavefront=128,
        continuous_compute=True,
    )
# Record the elapsed time before decoding and printing the full sequence, so
# the report covers generation only, not tokenizer/stdout overhead.
elapsed = time.perf_counter() - timer

print("-------------------------------------Final --------------------------------------------------------------------")
print(tokenizer.decode(outputs.sequences[0]))
print("---------------------------------------------------------------------------------------------------------------")
print(f"{elapsed}s -- Memory usage: {outputs.past_key_values.get_memory_usage()}MB")
