import torch
import time
from fastchat.model import get_conversation_template
from transformers import AutoTokenizer
from eagle.model.ea_model import EaModel

def get_prompt(message):
    """Render a single-turn prompt for ``message`` with the "llama" template.

    ``message`` is appended under the template's first role, then an empty
    (``None``) turn is appended under its second role, and the template's
    rendered prompt string is returned.
    """
    template = get_conversation_template("llama")
    # (role index, content) for each turn, in the order they are appended.
    # Presumably roles[0] is the user side and roles[1] the model side;
    # verify against the fastchat template definition.
    turns = ((0, message), (1, None))
    for role_index, content in turns:
        template.append_message(template.roles[role_index], content)
    return template.get_prompt()

def warmup_model(model, tokenizer, use_eagenerate=True):
    """Run one throwaway generation on a fixed prompt (intended as a warm-up).

    Args:
        model: Object exposing ``eagenerate`` and ``naivegenerate``.
        tokenizer: Callable tokenizer whose result exposes ``input_ids``.
        use_eagenerate: When true, call ``model.eagenerate``; otherwise call
            ``model.naivegenerate``.

    Returns:
        None. Whatever the generate call returns is discarded.
    """
    prompt = "Describe the formation and evolution of galaxies in the universe."
    encoded = tokenizer([prompt], return_tensors="pt", add_special_tokens=False)
    # Move the ids to the GPU once; the same tensor is handed to the model.
    input_ids = encoded.input_ids.cuda()

    # Both paths take the same arguments, so only the bound method differs.
    generate = model.eagenerate if use_eagenerate else model.naivegenerate
    generate(
        input_ids,
        temperature=0,
        max_new_tokens=1024,
        log=True,
        is_llama3=True,
    )

def _timed_generate(generate_fn, tokenizer, prompt):
    """Tokenize ``prompt``, time one ``generate_fn`` call, and decode the new tokens.

    Args:
        generate_fn: Callable invoked as
            ``generate_fn(input_ids, temperature=0.0, max_new_tokens=1024,
            log=True, is_llama3=True)`` and returning a 3-tuple whose first
            item is the output ids and whose second is the new-token count.
        tokenizer: Tokenizer used to encode ``prompt`` and decode the output.
        prompt: Prompt text, encoded without adding special tokens.

    Returns:
        Tuple ``(text, elapsed_seconds, new_token_count)`` where ``text`` is
        the decoded output after the prompt tokens and ``new_token_count`` is
        the count reported by ``generate_fn``.
    """
    inputs = tokenizer([prompt], return_tensors="pt", add_special_tokens=False).to("cuda")
    # Already on the GPU after .to("cuda"); no further transfer is needed.
    input_ids = inputs["input_ids"]
    input_len = input_ids.shape[-1]

    # Synchronize on both sides of the timer so the measurement covers
    # completed GPU work rather than just kernel launches.
    torch.cuda.synchronize()
    start = time.perf_counter()
    output_ids, new_token_count, _ = generate_fn(
        input_ids,
        temperature=0.0,
        max_new_tokens=1024,
        log=True,
        is_llama3=True,
    )
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

    # Drop the prompt tokens so only newly generated text is decoded.
    generated_ids = output_ids[0][input_len:]
    output = tokenizer.decode(generated_ids, skip_special_tokens=True)
    return output, elapsed, new_token_count

def run_eagle_and_time(model, tokenizer, prompt):
    """Time one ``model.eagenerate`` call on ``prompt``.

    Returns:
        Tuple ``(decoded_new_text, elapsed_seconds, new_token_count)``.
    """
    return _timed_generate(model.eagenerate, tokenizer, prompt)

def run_baseline_and_time(model, tokenizer, prompt):
    """Time one ``model.naivegenerate`` call on ``prompt``.

    Returns:
        Tuple ``(decoded_new_text, elapsed_seconds, new_token_count)``.
    """
    return _timed_generate(model.naivegenerate, tokenizer, prompt)

def main():
    """Benchmark ``eagenerate`` against ``naivegenerate`` on one prompt.

    Loads a single ``EaModel``, then for each generation path runs a warm-up
    followed by a timed run, printing the decoded output, the elapsed time and
    the tokens-per-second figure. Finally prints the wall-time ratio of the
    baseline run to the EAGLE run.
    """
    base_model_path = "/home/merdogan/EAGLE/DeepSeek-R1-Distill-Llama-8B"
    ea_model_path = "yuhuili/EAGLE3-DeepSeek-R1-Distill-LLaMA-8B"

    your_message = "Explain the principles of quantum mechanics and their applications in modern technology."
    prompt = get_prompt(your_message)

    print("Running EAGLE...")
    # Both generation paths are methods of the same EaModel configuration, so
    # the model is loaded once and shared. Loading a second identical copy
    # would keep two full sets of weights resident at the same time and, with
    # device_map="auto", could place the copies differently.
    model = EaModel.from_pretrained(
        base_model_path=base_model_path,
        ea_model_path=ea_model_path,
        total_token=60,
        depth=5,
        top_k=10,
        torch_dtype=torch.float16,
        trust_remote_code=True,
        device_map="auto",
    )
    model.eval()
    tokenizer = model.get_tokenizer()

    warmup_model(model, tokenizer, use_eagenerate=True)

    eagle_output, eagle_time, eagle_tokens = run_eagle_and_time(model, tokenizer, prompt)
    print("\n[EAGLE Output]:", eagle_output)
    print(f"[EAGLE Time]: {eagle_time:.3f} sec")
    print(f"[EAGLE Tokens/sec]: {eagle_tokens / eagle_time:.2f}")

    print("\nRunning Baseline (naivegenerate)...")
    # NOTE(review): relies on EaModel supporting naivegenerate after
    # eagenerate on the same instance; confirm against the EaModel source.
    warmup_model(model, tokenizer, use_eagenerate=False)

    base_output, base_time, base_tokens = run_baseline_and_time(model, tokenizer, prompt)
    print("\n[Baseline Output]:", base_output)
    print(f"[Baseline Time]: {base_time:.3f} sec")
    print(f"[Baseline Tokens/sec]: {base_tokens / base_time:.2f}")

    # Wall-time ratio; guarded so a zero EAGLE time reports "inf" instead of raising.
    speedup = base_time / eagle_time if eagle_time > 0 else float("inf")
    print(f"\n⚡ [EAGLE Speedup over Baseline]: {speedup:.2f}x")

# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()



"""import torch
import time
from fastchat.model import get_conversation_template
from transformers import AutoModelForCausalLM, AutoTokenizer
from eagle.model.ea_model import EaModel

def get_prompt(message):
    conv = get_conversation_template("llama")
    conv.append_message(conv.roles[0], message)
    conv.append_message(conv.roles[1], None)
    return conv.get_prompt()

def warmup_model(model, tokenizer, use_eagenerate=True):
    prompt = "Describe the formation and evolution of galaxies in the universe."
    input_ids = tokenizer([prompt], return_tensors="pt", add_special_tokens=False).input_ids.cuda()

    if use_eagenerate:
        _ = model.eagenerate(
            input_ids.cuda(),
            temperature=0,
            max_new_tokens=1024,
            log=True,
            is_llama3=True,
        )
    else:
        _ = model.generate(
            input_ids,
            max_new_tokens=1024,
            temperature=0,
            attention_mask=torch.ones_like(input_ids),
            pad_token_id=tokenizer.eos_token_id,
        )

def run_eagle_and_time(model, tokenizer, prompt):
    inputs = tokenizer([prompt], return_tensors="pt", add_special_tokens=False).to("cuda")
    torch.cuda.synchronize()
    start = time.perf_counter()
    output_ids, new_token_count, _ = model.eagenerate(
        inputs["input_ids"].cuda(),
        temperature=0.0,
        max_new_tokens=1024,
        log=True,
        is_llama3=True,
    )
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    generated_ids = output_ids[0][-new_token_count:]  # Only new tokens
    output = tokenizer.decode(generated_ids, skip_special_tokens=True)
    return output, elapsed, new_token_count

def run_baseline_and_time(model, tokenizer, prompt):
    inputs = tokenizer([prompt], return_tensors="pt", add_special_tokens=False).to("cuda")
    input_len = inputs["input_ids"].shape[-1]
    torch.cuda.synchronize()
    start = time.perf_counter()
    output_ids = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        temperature=0.0,
        max_new_tokens=1024,
        pad_token_id=tokenizer.eos_token_id,
    )
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    generated_ids = output_ids[0][input_len:]  # Only new tokens
    output = tokenizer.decode(generated_ids, skip_special_tokens=True)
    return output, elapsed, len(generated_ids)

def main():
    your_message = "Explain the principles of quantum mechanics and their applications in modern technology."
    prompt = get_prompt(your_message)

    print("Running EAGLE...")
    eagle_model = EaModel.from_pretrained(
        base_model_path="/home/merdogan/EAGLE/DeepSeek-R1-Distill-Llama-8B",
        ea_model_path="yuhuili/EAGLE3-DeepSeek-R1-Distill-LLaMA-8B",
        total_token=60,
        depth=5,
        top_k=10,
        torch_dtype=torch.float16,
        trust_remote_code=True,
        device_map="auto",
    )
    eagle_model.eval()
    eagle_tokenizer = eagle_model.get_tokenizer()

    warmup_model(eagle_model, eagle_tokenizer, use_eagenerate=True)

    eagle_output, eagle_time, eagle_tokens = run_eagle_and_time(eagle_model, eagle_tokenizer, prompt)
    print("\n[EAGLE Output]:", eagle_output)
    print(f"[EAGLE Time]: {eagle_time:.3f} sec")
    print(f"[EAGLE Tokens/sec]: {eagle_tokens / eagle_time:.2f}")

    print("\nRunning Baseline...")
    base_model = AutoModelForCausalLM.from_pretrained(
        "/home/merdogan/EAGLE/DeepSeek-R1-Distill-Llama-8B",
        torch_dtype=torch.float16,
        trust_remote_code=True,
        device_map="auto",
    )
    base_model.eval()
    base_tokenizer = AutoTokenizer.from_pretrained("/home/merdogan/EAGLE/DeepSeek-R1-Distill-Llama-8B", use_fast=False)

    warmup_model(base_model, base_tokenizer, use_eagenerate=False)

    base_output, base_time, base_tokens = run_baseline_and_time(base_model, base_tokenizer, prompt)
    print("\n[Baseline Output]:", base_output)
    print(f"[Baseline Time]: {base_time:.3f} sec")
    print(f"[Baseline Tokens/sec]: {base_tokens / base_time:.2f}")

    #speedup = base_time / eagle_time if eagle_time > 0 else float("inf")
    #print(f"\n⚡ [EAGLE Speedup over Baseline]: {speedup:.2f}x")

if __name__ == "__main__":
    main()"""