#!/usr/bin/env python3
"""
Demo script comparing standard generation, SCSD v2, and SCSD v3.
SCSD v3 introduces cross-block draft generation with fixed draft length.
"""

import argparse
import os
import re
import time
from datetime import datetime
from typing import Dict, List, Optional, Union

import torch
from transformers import AutoModel, AutoTokenizer

from model.modeling_llada import LLaDAModelLM
# from utils.generate_function import _generate_standard_verbose, generate
from utils.generate_ssd import ssd_without_cache as generate_v3
from utils.generate_ssd_cache import ssd_with_cache as generate_v3_cache


# ANSI color codes
class Colors:
    """ANSI escape sequences used to colorize the demo's terminal output."""

    # Foreground colors.
    RED = "\033[91m"
    GREEN = "\033[92m"
    YELLOW = "\033[93m"
    BLUE = "\033[94m"
    MAGENTA = "\033[95m"
    CYAN = "\033[96m"

    # Text attributes; RESET is emitted after every colored span in this file.
    BOLD = "\033[1m"
    RESET = "\033[0m"


# Few-shot examples from lm_eval Minerva math.
# Each entry is a dict with a "problem" statement and a worked "solution".
# Every solution ends with the sentence
# "Final Answer: The final answer is ... I hope it is correct.", which is the
# pattern get_unnormalized_answer() searches for in generated text.
FEW_SHOT_EXAMPLES = [
    {
        # NOTE(review): the problem text ends with a stray "}" after the period —
        # presumably inherited from the upstream prompt; confirm before changing.
        "problem": "Find the domain of the expression  $\\frac{\\sqrt{x-2}}{\\sqrt{5-x}}$.}",
        "solution": "The expressions inside each square root must be non-negative. Therefore, $x-2 \\ge 0$, so $x\\ge2$, and $5 - x \\ge 0$, so $x \\le 5$. Also, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$. Therefore, the domain of the expression is $\\boxed{[2,5)}$.\nFinal Answer: The final answer is $[2,5)$. I hope it is correct.",
    },
    {
        "problem": "If $\\det \\mathbf{A} = 2$ and $\\det \\mathbf{B} = 12,$ then find $\\det (\\mathbf{A} \\mathbf{B}).$",
        "solution": "We have that $\\det (\\mathbf{A} \\mathbf{B}) = (\\det \\mathbf{A})(\\det \\mathbf{B}) = (2)(12) = \\boxed{24}.$\nFinal Answer: The final answer is $24$. I hope it is correct.",
    },
    {
        "problem": "Terrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight?",
        # NOTE(review): the "\\n" sequences inside this solution are a literal
        # backslash followed by "n", not newline characters (unlike the "\n"
        # before "Final Answer"). Looks like double escaping — confirm whether
        # real newlines were intended.
        "solution": "If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\\cdot 12\\cdot20=480$ pounds of weight.  If he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\\cdot15\\cdot n=30n$ pounds of weight.  Equating this to 480 pounds, we can solve for $n$:\\n\\begin{align*}\\n30n&=480\\\\\\n\\Rightarrow\\qquad n&=480/30=\\boxed{16}\\n\\end{align*}\nFinal Answer: The final answer is $16$. I hope it is correct.",
    },
    {
        "problem": "If the system of equations\n\n\\begin{align*}\n6x-4y&=a,\\\\\n6y-9x &=b.\n\\end{align*}has a solution $(x, y)$ where $x$ and $y$ are both nonzero,\nfind $\\frac{a}{b},$ assuming $b$ is nonzero.",
        "solution": "If we multiply the first equation by $-\\frac{3}{2}$, we obtain\n\n$$6y-9x=-\\frac{3}{2}a.$$Since we also know that $6y-9x=b$, we have\n\n$$-\\frac{3}{2}a=b\\Rightarrow\\frac{a}{b}=\\boxed{-\\frac{2}{3}}.$$\nFinal Answer: The final answer is $-\\frac{2}{3}$. I hope it is correct.",
    },
]


def doc_to_text(problem: str) -> str:
    """Render *problem* as ``Problem:\\n<problem>\\n\\nSolution:`` (lm_eval-style layout)."""
    # The empty element yields the blank line between the problem and "Solution:".
    return "\n".join(("Problem:", problem, "", "Solution:"))


def create_few_shot_prompt(
    problem: str, use_chat_template: bool = True
) -> Union[str, List[Dict[str, str]]]:
    """Build the few-shot prompt for *problem*.

    Every entry of ``FEW_SHOT_EXAMPLES`` is rendered as
    ``doc_to_text(example["problem"]) + " " + example["solution"]`` and the
    target problem is rendered with ``doc_to_text`` alone, so the prompt ends
    on an open ``Solution:``. The rendered pieces are separated by a blank
    line.

    Args:
        problem: Statement of the problem the model should solve.
        use_chat_template: When true, wrap the prompt in a single-message
            chat list; otherwise return the raw prompt text.

    Returns:
        ``[{"role": "user", "content": prompt}]`` when ``use_chat_template``
        is true, else the prompt string itself.
    """
    # Keep each solution on the same line as its "Solution:" header. Adding
    # the " "-prefixed solution as a separate part would put the "\n\n"
    # separator between the header and its answer.
    prompt_parts = [
        doc_to_text(example["problem"]) + " " + example["solution"]
        for example in FEW_SHOT_EXAMPLES
    ]

    # The actual problem goes last, with its solution left for the model.
    prompt_parts.append(doc_to_text(problem))

    full_prompt = "\n\n".join(prompt_parts)

    if use_chat_template:
        # For instruct models, wrap in chat template
        return [{"role": "user", "content": full_prompt}]
    return full_prompt


def get_unnormalized_answer(text: str) -> str:
    """Return the answer quoted in the closing "Final Answer" sentence of *text*.

    The search is case-insensitive and the answer may span several lines.
    When the sentence is absent, the sentinel ``"[invalidanswer]"`` is
    returned instead.
    """
    found = re.search(
        r"Final Answer: The final answer is(.*?)\. I hope it is correct\.",
        text,
        flags=re.IGNORECASE | re.DOTALL,
    )
    return found.group(1).strip() if found else "[invalidanswer]"


def last_boxed_only(text: str) -> Optional[str]:
    """Return the last ``\\boxed{...}`` expression in *text*, wrapper included.

    Scanning starts at the final occurrence of ``\\boxed`` and tracks brace
    depth, so nested braces inside the box are kept intact.

    Returns:
        The substring from ``\\boxed`` through its matching closing brace, or
        ``None`` when *text* has no ``\\boxed`` or its braces never balance.
    """
    start = text.rfind("\\boxed")
    if start < 0:
        return None

    depth = 0
    for pos in range(start, len(text)):
        char = text[pos]
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                # This brace closes the one opened right after "\boxed".
                return text[start:pos + 1]

    return None


def remove_boxed(s: str) -> str:
    """Strip a leading ``\\boxed`` wrapper from *s*.

    Two forms are recognised:

    * ``\\boxed <answer>``  -> ``<answer>``
    * ``\\boxed{<answer>}`` -> ``<answer>``

    Any other string is returned unchanged.
    """
    spaced_prefix = "\\boxed "
    # Only treat the space form as a wrapper when it really is a prefix; a
    # "\boxed " nested inside a braced box must not trigger this slice.
    if s.startswith(spaced_prefix):
        return s[len(spaced_prefix):]

    braced_prefix = "\\boxed{"
    if s.startswith(braced_prefix) and s.endswith("}"):
        return s[len(braced_prefix):-1]

    return s


def normalize_final_answer(final_answer: str) -> str:
    """Canonicalize a final answer string so answers can be compared textually.

    Steps, in order: keep only the text after the last ``=``; apply literal
    substitutions; delete unit words and TeX spacing; unwrap ``$...$``,
    ``\\text``, ``\\textbf``, ``\\overline`` and ``\\boxed``; expand shorthand
    ``\\frac12`` / ``\\sqrt2``; drop ``$``; and remove thousands separators
    from purely numeric answers.
    """
    # Literal (old, new) replacements from lm_eval; order matters.
    literal_swaps = (
        ("an ", ""),
        ("a ", ""),
        (".$", "$"),
        ("\\$", ""),
        (r"\ ", ""),
        (" ", ""),
        ("mbox", "text"),
        (",\\text{and}", ","),
        ("\\text{and}", ","),
        ("\\text{m}", "\\text{}"),
    )

    # Fragments deleted outright (units, filler words, TeX spacing).
    dropped_fragments = (
        "square", "ways", "integers", "dollars", "mph", "inches", "ft",
        "hours", "km", "units", "\\ldots", "sue", "points", "feet",
        "minutes", "digits", "cents", "degrees", "cm", "gm", "pounds",
        "meters", "meals", "edges", "students", "childrentickets", "multiples",
        "\\text{s}", "\\text{.}", "\\text{\ns}", "\\text{}^2", "\\text{}^3",
        "\\text{\n}", "\\text{}", r"\\mathrm{th}", r"^\circ", r"^{\circ}",
        r"\;", r",\!", "{,}", '"', "\\dots",
    )

    # (pattern, replacement) pairs run through re.sub in this order.
    regex_rewrites = (
        # Keep only the first $...$ span when the answer contains LaTeX math.
        (r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$"),
        # Unwrap formatting commands, keeping their argument.
        (r"(\\text\{)(.*?)(\})", "\\2"),
        (r"(\\textbf\{)(.*?)(\})", "\\2"),
        (r"(\\overline\{)(.*?)(\})", "\\2"),
        (r"(\\boxed\{)(.*?)(\})", "\\2"),
        # Expand shorthand TeX: frac12 -> frac{1}{2}, sqrt2 -> sqrt{2}.
        (r"(frac)([^{])(.)", "frac{\\2}{\\3}"),
        (r"(sqrt)([^{])", "sqrt{\\2}"),
    )

    answer = final_answer.split("=")[-1]

    for old, new in literal_swaps:
        answer = answer.replace(old, new)
    for fragment in dropped_fragments:
        answer = answer.replace(fragment, "")
    for pattern, replacement in regex_rewrites:
        answer = re.sub(pattern, replacement, answer)

    answer = answer.replace("$", "")

    # Normalize 100,000 -> 100000, but only for purely numeric answers.
    without_commas = answer.replace(",", "")
    return without_commas if without_commas.isdigit() else answer


def print_mode_description(mode):
    """Print a colored headline and feature bullets for *mode*.

    Known modes are "standard", "scsd_v2" and "scsd_v3"; any other value
    prints nothing.
    """
    catalogue = (
        (
            "standard",
            f"{Colors.BLUE}Standard Mode:{Colors.RESET}",
            (
                "  • Block-based generation with fixed step count",
                "  • Each block uses allocated steps uniformly",
                "  • Simple and predictable behavior",
            ),
        ),
        (
            "scsd_v2",
            f"{Colors.GREEN}SCSD v2 (Self Cascaded Speculative Decoding):{Colors.RESET}",
            (
                "  • ✓ Guaranteed progress (≥1 token/iteration)",
                "  • ✓ O(2n) verification complexity",
                "  • ✓ Draft maintenance with confidence tracking",
                "  • ✓ Adaptive switching to standard decode",
            ),
        ),
        (
            "scsd_v3",
            f"{Colors.MAGENTA}SCSD v3 (Cross-Block Speculative Decoding):{Colors.RESET}",
            (
                "  • ✓ Fixed draft length k (default 4)",
                "  • ✓ Cross-block draft generation",
                "  • ✓ Block-aware priority verification",
                "  • ✓ Extensible multi-candidate design",
                "  • ✓ Optimized for short effective drafts",
                "  • ✓ Support for multiple tree strategies (greedy, greedy_with_alternatives)",
            ),
        ),
    )

    for name, headline, bullets in catalogue:
        if mode == name:
            print(headline)
            for bullet in bullets:
                print(bullet)
            break


def parse_args():
    """Declare the demo's command-line options and parse ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` holding generation, v3, model and output
        settings.
    """
    # (flags, add_argument keyword arguments), registered in listing order so
    # the --help output keeps the same layout.
    option_table = [
        # Generation parameters
        (("--mode",), dict(type=str, default="standard",
                           choices=["standard", "scsd_v2", "scsd_v3"],
                           help="Generation mode to use")),
        (("--steps",), dict(type=int, default=256,
                            help="Number of generation steps (for standard)")),
        (("--gen_length",), dict(type=int, default=128,
                                 help="Generation length in tokens")),
        (("--block_length",), dict(type=int, default=8,
                                   help="Block length for semi-AR mode")),
        (("--temperature",), dict(type=float, default=0.0,
                                  help="Temperature for generation")),
        (("--cfg_scale",), dict(type=float, default=0.0,
                                help="Classifier-free guidance scale")),
        (("--cache_strategy",), dict(type=str, default="fast_dllm",
                                     help="Which cache strategy to use None/fast_dllm")),
        # v3 specific parameters
        (("--draft_length",), dict(type=int, default=4,
                                   help="Fixed draft length for v3")),
        (("--tree_strategy",), dict(type=str, default="greedy",
                                    choices=["greedy", "greedy_with_alternatives"],
                                    help="Tree building strategy for v3")),
        (("--num_alternatives",), dict(
            type=int, default=1,
            help="Number of alternatives per position (for greedy_with_alternatives strategy)")),
        # Model parameters
        (("--model_path",), dict(type=str,
                                 default="/data_disk/jza/models/LLaDA-8B-Instruct",
                                 help="Path to the model")),
        (("--device",), dict(type=str, default=None, help="Device to run on")),
        # Output control
        (("--verbose",), dict(action="store_true",
                              help="Print detailed generation progress")),
        (("--show_prompt",), dict(action="store_true",
                                  help="Show full input prompt")),
        (("--compare",), dict(action="store_true", help="Compare all modes")),
        (("--compare_v2_v3",), dict(action="store_true",
                                    help="Compare only v2 and v3")),
        (("--debug",), dict(action="store_true",
                            help="Enable debug mode with token analysis")),
    ]

    cli = argparse.ArgumentParser(description="Standard vs SCSD v2 vs SCSD v3 Comparison")
    for flags, spec in option_table:
        cli.add_argument(*flags, **spec)

    return cli.parse_args()


# Token id used for masked positions; passed to the generators and used to
# detect positions left masked in the output.
_MASK_TOKEN_ID = 126336


def _generate_scsd_v3(model, input_ids, attention_mask, args, gen_length, verbose):
    """Run one SCSD v3 generation, choosing the uncached or cached backend.

    ``args.cache_strategy == "None"`` selects ``generate_v3``; any other value
    selects ``generate_v3_cache``. Both receive the same keyword arguments.
    """
    generate_fn = generate_v3 if args.cache_strategy == "None" else generate_v3_cache
    return generate_fn(
        input_ids=input_ids,
        attention_mask=attention_mask,
        model=model,
        gen_length=gen_length,
        block_length=args.block_length,
        temperature=args.temperature,
        cfg_scale=args.cfg_scale,
        mask_id=_MASK_TOKEN_ID,
        draft_length=args.draft_length,
        tree_strategy=args.tree_strategy,
        num_alternatives=args.num_alternatives,
        verbose=verbose,
    )


def run_generation(mode, model, tokenizer, input_ids, attention_mask, args, warmup=True):
    """Run generation with the specified mode and return timing and answer data.

    Args:
        mode: Generation mode name. Only "scsd_v3" has a backend wired into
            this script.
        model: Model handed to the generation function.
        tokenizer: Tokenizer used to decode the generated ids.
        input_ids: Prompt token ids (assumes a batch-first tensor with one
            row — TODO confirm against the generation functions).
        attention_mask: Attention mask matching ``input_ids``.
        args: Parsed command-line namespace (see ``parse_args``).
        warmup: When true, run a short untimed generation (at most 8 tokens)
            first so the timed run is not skewed by one-off startup costs.

    Returns:
        A dict with keys "mode", "time" (seconds), "answer" (decoded text up
        to the first "Problem:"), "raw_answer", "normalized_answer",
        "tokens_per_second" (``args.gen_length`` / time), "has_masks", and
        "token_ids" (a list only when ``args.debug`` is set, else ``None``).

    Raises:
        NotImplementedError: For any mode other than "scsd_v3".
    """
    print(f"\n{Colors.BOLD}Running {mode.upper()} mode...{Colors.RESET}")
    print_mode_description(mode)

    # The standard and SCSD v2 backends are not imported by this script, so
    # nothing can produce tokens for them. Fail with a clear message instead
    # of hitting an unbound local variable further down.
    if mode != "scsd_v3":
        raise NotImplementedError(
            f"Generation mode '{mode}' is not available in this script: "
            "only 'scsd_v3' has a generation backend wired in."
        )

    # Warmup run to ensure fair comparison; its output is discarded.
    if warmup:
        print(f"  Warming up...")
        _generate_scsd_v3(
            model, input_ids, attention_mask, args,
            gen_length=min(8, args.gen_length),
            verbose=False,
        )

    # perf_counter is monotonic, so the interval cannot go negative if the
    # wall clock is adjusted mid-run.
    start_time = time.perf_counter()
    result = _generate_scsd_v3(
        model, input_ids, attention_mask, args,
        gen_length=args.gen_length,
        verbose=args.verbose,
    )
    generation_time = time.perf_counter() - start_time

    # The generator may return either the tokens alone or (tokens, forward_passes).
    if isinstance(result, tuple):
        generation_ids, forward_passes = result
        if args.verbose:
            print(f"Forward passes used: {forward_passes}")
    else:
        generation_ids = result

    # Check for mask tokens in output
    has_masks = (generation_ids == _MASK_TOKEN_ID).any().item()

    # Decode response
    full_response = tokenizer.batch_decode(generation_ids, skip_special_tokens=True)[0]

    # Post-process: drop everything from the first "Problem:" onward, i.e. any
    # follow-up problem the model starts writing after its solution.
    answer = full_response.split("Problem:")[0].strip()

    # Extract answer: prefer the "Final Answer" sentence, fall back to the
    # last \boxed{...} expression.
    raw_answer = get_unnormalized_answer(answer)
    if raw_answer == "[invalidanswer]":
        boxed = last_boxed_only(answer)
        if boxed:
            raw_answer = remove_boxed(boxed)

    normalized_answer = normalize_final_answer(raw_answer)

    # Guard against a zero-length interval on coarse clocks.
    if generation_time > 0:
        tokens_per_second = args.gen_length / generation_time
    else:
        tokens_per_second = float("inf")

    return {
        "mode": mode,
        "time": generation_time,
        "answer": answer,
        "raw_answer": raw_answer,
        "normalized_answer": normalized_answer,
        "tokens_per_second": tokens_per_second,
        "has_masks": has_masks,
        "token_ids": generation_ids[0].tolist() if args.debug else None
    }


def _run_and_report(modes, model, tokenizer, input_ids, attention_mask, args):
    """Run each mode in *modes*, print its summary, and return the result dicts."""
    results = []
    for mode in modes:
        result = run_generation(mode, model, tokenizer, input_ids, attention_mask, args)
        results.append(result)

        print(f"\n{Colors.BOLD}Result for {mode}:{Colors.RESET}")
        print(f"Time: {Colors.YELLOW}{result['time']:.2f}s{Colors.RESET}")
        print(f"Speed: {Colors.GREEN}{result['tokens_per_second']:.2f}{Colors.RESET} tokens/s")
        print(f"Answer: {Colors.BLUE}{result['normalized_answer']}{Colors.RESET}")

        if result['has_masks']:
            print(f"{Colors.RED}⚠ WARNING: Output contains mask tokens!{Colors.RESET}")

        # Token ids are only collected when --debug is set.
        if args.debug and result['token_ids']:
            print(f"First 10 tokens: {result['token_ids'][:10]}")
    return results


def main():
    """Entry point: parse options, load the model, run the demo, print results.

    Depending on the flags, this runs a single mode (``--mode``), all three
    modes (``--compare``), or v2 and v3 only (``--compare_v2_v3``) on one
    fixed math problem and prints the answers and timing.
    """
    args = parse_args()

    # Configuration
    device = args.device if args.device else ("cuda" if torch.cuda.is_available() else "cpu")

    # Adjust gen_length if needed (only for v2 and standard). block_length is
    # checked first so that a value of 0 cannot trigger a ZeroDivisionError.
    if (
        args.mode != "scsd_v3"
        and args.block_length > 0
        and args.gen_length % args.block_length != 0
    ):
        print(f"Warning: Adjusting gen_length to be divisible by block_length")
        args.gen_length = (args.gen_length // args.block_length) * args.block_length

    num_blocks = args.gen_length // args.block_length if args.block_length > 0 else 0

    # Header
    print("=" * 80)
    print(f"{Colors.BOLD}Standard vs SCSD v2 vs SCSD v3 Generation Comparison{Colors.RESET}")
    print("=" * 80)
    print(f"Device: {device}")
    print(f"Model: {args.model_path}")
    print(f"\n{Colors.YELLOW}Generation Settings:{Colors.RESET}")
    print(f"  Generation Length: {args.gen_length} tokens")
    if args.mode != "scsd_v3":
        print(f"  Block Length: {args.block_length} tokens")
        print(f"  Number of Blocks: {num_blocks}")
    if args.mode == "standard":
        print(f"  Steps: {args.steps}")
    if args.mode == "scsd_v3" or args.compare or args.compare_v2_v3:
        print(f"  Draft Length (v3): {args.draft_length} tokens")
        print(f"  Tree Strategy (v3): {args.tree_strategy}")
        if args.tree_strategy == "greedy_with_alternatives":
            print(f"  Alternatives per position: {args.num_alternatives}")
    print(f"  Temperature: {args.temperature}")
    print(f"  CFG Scale: {args.cfg_scale}")

    if args.compare:
        print(f"\n{Colors.MAGENTA}Comparison Mode: Will run all three algorithms{Colors.RESET}")
    elif args.compare_v2_v3:
        print(f"\n{Colors.MAGENTA}Comparison Mode: Will run v2 and v3{Colors.RESET}")
    else:
        print(f"\n{Colors.CYAN}Mode: {args.mode}{Colors.RESET}")
    print("=" * 80)

    # Load model
    print(f"Loading model...")
    start_load = time.time()

    # The "fast_dllm" cache strategy uses the project's LLaDAModelLM class;
    # every other value loads through transformers' AutoModel.
    model_cls = LLaDAModelLM if args.cache_strategy == "fast_dllm" else AutoModel
    model = (
        model_cls.from_pretrained(
            args.model_path,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16
        )
        .to(device)
        .eval()
    )
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_path,
        trust_remote_code=True
    )

    load_time = time.time() - start_load
    print(f"Model loaded in {Colors.GREEN}{load_time:.2f}{Colors.RESET} seconds")

    # Test problem
    math_problem = r"""A star has a measured parallax of $0.01^{\prime \prime}$, that is, $0.01$ arcseconds. How far away is it, in parsecs?"""

    # Create prompt
    prompt = create_few_shot_prompt(math_problem, use_chat_template=True)
    formatted_input = tokenizer.apply_chat_template(
        prompt, add_generation_prompt=True, tokenize=False
    )

    if args.show_prompt:
        print("\nInput prompt:")
        print("-" * 50)
        print(formatted_input)
        print("-" * 50)

    # Encode input once and reuse both fields; unsqueeze adds the batch dimension.
    encoded = tokenizer(formatted_input)
    input_ids = torch.tensor(encoded["input_ids"]).to(device).unsqueeze(0)
    attention_mask = torch.tensor(encoded["attention_mask"]).to(device).unsqueeze(0)

    print(f"Input sequence length: {input_ids.shape[1]} tokens")

    # Run generation(s)
    if args.compare:
        # Compare all modes
        results = _run_and_report(
            ["standard", "scsd_v2", "scsd_v3"],
            model, tokenizer, input_ids, attention_mask, args,
        )

        # Comparison table
        print("\n" + "=" * 80)
        print(f"{Colors.BOLD}Performance Comparison:{Colors.RESET}")
        print("-" * 80)
        print(f"{'Mode':<15} {'Time (s)':<12} {'Speed (tok/s)':<15} {'Answer':<20} {'Valid':<10}")
        print("-" * 80)

        best_time = min(r['time'] for r in results)
        for r in results:
            # Highlight the fastest run and flag outputs that still contain masks.
            time_color = Colors.GREEN if r['time'] == best_time else Colors.RESET
            valid_color = Colors.GREEN if not r['has_masks'] else Colors.RED
            valid_text = "✓" if not r['has_masks'] else "✗ (masks)"

            print(f"{r['mode']:<15} {time_color}{r['time']:<12.2f}{Colors.RESET} "
                  f"{r['tokens_per_second']:<15.2f} {r['normalized_answer']:<20} "
                  f"{valid_color}{valid_text}{Colors.RESET}")
        print("=" * 80)

        # Check if answers match
        if len(set(r['normalized_answer'] for r in results)) > 1:
            print(f"\n{Colors.YELLOW}⚠ WARNING: Different answers generated!{Colors.RESET}")
            print("This indicates a potential issue with one of the algorithms.")

    elif args.compare_v2_v3:
        # Compare v2 and v3 only
        results = _run_and_report(
            ["scsd_v2", "scsd_v3"],
            model, tokenizer, input_ids, attention_mask, args,
        )

        # Comparison table
        print("\n" + "=" * 80)
        print(f"{Colors.BOLD}v2 vs v3 Comparison:{Colors.RESET}")
        print("-" * 80)
        print(f"{'Mode':<15} {'Time (s)':<12} {'Speed (tok/s)':<15} {'Speedup':<10}")
        print("-" * 80)

        # v2 is the baseline: its own speedup is 1.0 by definition.
        v2_time = results[0]['time']
        for r in results:
            speedup = v2_time / r['time'] if r['mode'] == 'scsd_v3' else 1.0
            speedup_color = Colors.GREEN if speedup > 1.0 else Colors.RESET
            # Attach the "x" before padding so it sits next to the number.
            speedup_text = f"{speedup:.2f}x"

            print(f"{r['mode']:<15} {r['time']:<12.2f} "
                  f"{r['tokens_per_second']:<15.2f} "
                  f"{speedup_color}{speedup_text:<10}{Colors.RESET}")
        print("=" * 80)

    else:
        # Single mode
        result = run_generation(args.mode, model, tokenizer, input_ids, attention_mask, args)

        print(f"\n{Colors.BOLD}Generated Solution:{Colors.RESET}")
        print("=" * 80)
        print(result['answer'])
        print("=" * 80)

        print(f"\n{Colors.BOLD}Extracted Answer:{Colors.RESET}")
        print(f"Raw: {Colors.BLUE}{result['raw_answer']}{Colors.RESET}")
        print(f"Normalized: {Colors.GREEN}{result['normalized_answer']}{Colors.RESET}")

        if result['has_masks']:
            print(f"\n{Colors.RED}⚠ WARNING: Output contains mask tokens!{Colors.RESET}")
            print("This indicates incomplete generation.")

        print(f"\n{Colors.BOLD}Performance Metrics:{Colors.RESET}")
        print(f"Generation Time: {Colors.YELLOW}{result['time']:.2f}{Colors.RESET} seconds")
        print(f"Speed: {Colors.GREEN}{result['tokens_per_second']:.2f}{Colors.RESET} tokens/second")

    print(f"\n{Colors.GREEN}✓ Demo completed!{Colors.RESET}")

# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()


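# Example usage:
# NOTE: the standard and scsd_v2 generation calls in run_generation() are
# currently commented out, so only the examples that use scsd_v3 alone run.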
# Single mode:
#   python demo_inference_v3.py --mode standard --gen_length 256 --block_length 32
#   python demo_inference_v3.py --mode scsd_v2 --gen_length 256 --block_length 32
#   python demo_inference_v3.py --mode scsd_v3 --gen_length 256 --draft_length 4
#
# Using new tree strategy with alternatives:
#   python demo_inference_v3.py --mode scsd_v3 --tree_strategy greedy_with_alternatives --num_alternatives 3 --gen_length 128
#   python demo_inference_v3.py --mode scsd_v3 --tree_strategy greedy_with_alternatives --num_alternatives 2 --verbose --gen_length 32
#
# Compare all modes:
#   python demo_inference_v3.py --compare --gen_length 64
#
# Compare v2 and v3:
#   python demo_inference_v3.py --compare_v2_v3 --gen_length 128 --draft_length 4
#
# With verbose output:
#   python demo_inference_v3.py --mode scsd_v3 --verbose --gen_length 32
#   python demo_inference_v3.py --mode scsd_v3 --tree_strategy greedy_with_alternatives --num_alternatives 3 --verbose --gen_length 32
#
# Debug mode (shows token IDs):
#   python demo_inference_v3.py --compare_v2_v3 --debug --gen_length 16